Skip to content

[Example][BugFix] 1SM GEMM example on Blackwell and fix handling of mbar#1774

Open
Rachmanino wants to merge 12 commits intotile-ai:mainfrom
Rachmanino:gemm
Open

[Example][BugFix] 1SM GEMM example on Blackwell and fix handling of mbar#1774
Rachmanino wants to merge 12 commits intotile-ai:mainfrom
Rachmanino:gemm

Conversation

@Rachmanino
Copy link
Collaborator

@Rachmanino Rachmanino commented Feb 2, 2026

  • 1sm gemm on B200
    • non persistent: ~1450T
    • persistent: ~1550T
  • Handle mbar as BufferLoad to avoid missing index

Thanks @Hamerlate for providing the dev machine.

Summary by CodeRabbit

  • New Features

    • Added two comprehensive GEMM kernel examples demonstrating TileLang JIT compilation with integrated performance benchmarking and PyTorch validation
    • Includes both non-persistent and persistent 1-SM kernel implementations with fully configurable tiling and staging parameters
  • Refactor

    • Enhanced internal buffer handling infrastructure for GEMM operations

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 2, 2026

📝 Walkthrough

Walkthrough

This PR refactors the mbar (memory barrier) representation throughout the codebase, converting from a BufferRegion-based system to a direct BufferLoad approach. It adds a new utility function MakeAccessPtrFromBufferLoad, updates TileLang APIs to support this new type, and introduces two new GEMM example kernels demonstrating the updated implementation.

Changes

Cohort / File(s) Summary
Core mbar type refactoring
src/op/gemm.h, src/op/gemm.cc, src/op/gemm_py.h, src/op/gemm_py.cc
Changed mbar representation from BufferRegion-based to direct tir::BufferLoad; removed mbarRegion_ field; updated reflection to expose mbar_ and cCoords_ publicly.
Utility function for mbar access
src/op/utils.h, src/op/utils.cc
Added new MakeAccessPtrFromBufferLoad() function that constructs row-major offsets and tvm_access_ptr calls from BufferLoad inputs, mirroring existing MakeAccessPtrFromRegion logic.
TileLang API updates
tilelang/language/gemm_op.py, tilelang/tileop/gemm/gemm_base.py, tilelang/tileop/gemm/gemm_tcgen05.py, tilelang/language/builtin.py
Updated mbar handling: changed mbar() return type to tir.BufferLoad | None; added retrieve_ptr() usage for pointer conversion; updated tcgen05_mma_arrive() signature to accept Buffer | BufferLoad | PrimExpr with runtime conversion logic; replaced mbar == 0 checks with mbar is None.
New GEMM example kernels
examples/gemm_sm100/gemm_tcgen5mma_ws.py, examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
Introduced two TileLang-based GEMM implementations: non-persistent 1-SM kernel (91 lines) and persistent 1-SM kernel with two epilogue stages (154 lines); both include kernel source printing, PyTorch validation, and latency/TFLOPS benchmarking.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • lucifer1004

Poem

🐰 With barriers now buffered and loaded with care,
We hopped through the code, refactoring rare!
From regions to loads, our kernels now glow,
Two GEMM examples bloom—persistent and lean flow! 🌱✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 39.29% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately reflects the two main changes: adding a 1SM GEMM example for Blackwell and fixing mbar handling, which are clearly supported by the raw summary.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link

github-actions bot commented Feb 2, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@Rachmanino Rachmanino marked this pull request as ready for review February 3, 2026 07:39
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@src/op/gemm_py.cc`:
- Around line 83-86: When handling the optional mbar argument in the arg-parsing
block, don't silently ignore non-BufferLoad values: if args.size() > 16 then
attempt the BufferLoadNode cast as currently done (check for BufferLoadNode and
set node->mbar_ via Downcast<BufferLoad>), but add an else branch that fails
fast (throw or LOG(FATAL)/CHECK) reporting that arg 16 was expected to be a
BufferLoad and include the actual argument's type/name (use whatever runtime
type introspection is available on the Expr/Node to include in the message) so
callers get a clear error instead of silently skipping mbar_.

In `@src/op/gemm_py.h`:
- Line 32: mbar_ is declared as a non-optional tir::BufferLoad but is only
conditionally assigned (when args.size() > 16 and args[16] is a BufferLoadNode),
causing a type-contract mismatch; change the field declaration from
tir::BufferLoad mbar_ to std::optional<tir::BufferLoad> mbar_, update the
parser/initializer (where args is inspected) to emplace/assign mbar_ only in the
conditional branch, and adjust any uses of mbar_ (check has_value() or use
value_or) so code and the Python bindings safely handle the absent case;
alternatively, if you prefer non-optional, ensure mbar_ is unconditionally
initialized in the same constructor code path and remove the "optional" comment.

In `@src/op/utils.cc`:
- Around line 95-122: The function MakeAccessPtrFromBufferLoad uses hard-coded
DataType::Int(32) for offset, stride and extent which can overflow for large
buffers; change all occurrences of make_const(DataType::Int(32), ...) and the
IntImm for rw_mask to use the buffer's index dtype (buf->index_dtype) instead:
initialize offset and stride with make_const(buf->index_dtype, 0/1), compute
offset/stride arithmetic with that dtype, set extent using
make_const(buf->index_dtype, 1), and construct the rw_mask as
IntImm(buf->index_dtype, rw_mask) when building acc_args; update references
inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args)
accordingly.
🧹 Nitpick comments (1)
src/op/gemm.cc (1)

84-88: Consider hard-failing on unexpected mbar argument type.

When args.size() > 16, a non-BufferLoad value is silently dropped. That can mask frontend mismatches. Prefer an ICHECK to fail early when the caller passes the wrong type.

Suggested tightening
 if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    } else {
-      node->mbar_ = std::nullopt;
-    }
+    ICHECK(args[16].as<BufferLoadNode>())
+        << "mbar must be provided as BufferLoad when present";
+    node->mbar_ = Downcast<BufferLoad>(args[16]);
 }

Comment on lines 83 to 86
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
node->mbar_ = Downcast<BufferLoad>(args[16]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fail fast when mbar is present but not a BufferLoad.

Silently ignoring non-BufferLoad inputs can mask call-site errors and lead to a missing barrier later; an explicit check makes failures clearer.

💡 Suggested fix
-  if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    }
-  }
+  if (args.size() > 16) {
+    ICHECK(args[16].as<BufferLoadNode>())
+        << "mbar must be provided as a BufferLoad when present";
+    node->mbar_ = Downcast<BufferLoad>(args[16]);
+  }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
node->mbar_ = Downcast<BufferLoad>(args[16]);
}
if (args.size() > 16) {
ICHECK(args[16].as<BufferLoadNode>())
<< "mbar must be provided as a BufferLoad when present";
node->mbar_ = Downcast<BufferLoad>(args[16]);
}
🤖 Prompt for AI Agents
In `@src/op/gemm_py.cc` around lines 83 - 86, When handling the optional mbar
argument in the arg-parsing block, don't silently ignore non-BufferLoad values:
if args.size() > 16 then attempt the BufferLoadNode cast as currently done
(check for BufferLoadNode and set node->mbar_ via Downcast<BufferLoad>), but add
an else branch that fails fast (throw or LOG(FATAL)/CHECK) reporting that arg 16
was expected to be a BufferLoad and include the actual argument's type/name (use
whatever runtime type introspection is available on the Expr/Node to include in
the message) so callers get a clear error instead of silently skipping mbar_.

PrimExpr clearAccum_ = const_false();
BufferRegion mbarRegion_;
tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify mbar_ usage for null-safety and initialization paths
rg -n -C3 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 1338


🏁 Script executed:

# Also check the GemmNode definition to see the original type
rg -n -C3 'struct GemmNode\|class GemmNode' src/op/gemm_py.h src/op/gemm_py.cc

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for mbar_ initialization in constructors or Create methods
rg -n 'mbar_\s*=' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 132


🏁 Script executed:

# Check if .defined() or other null checks are used with mbar_
rg -n 'mbar_.*defined|if.*mbar_' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for how mbar_ is used/read (not just assigned)
rg -n 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h -A 5 -B 2

Repository: tile-ai/tilelang

Length of output: 1511


🏁 Script executed:

# Look for uses of mbar in generated IR or kernel code
rg -n 'mbar' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 286


🏁 Script executed:

# Check if there's a TVM BufferLoad definition to understand default initialization
fd -t f '\.h$' -x grep -l 'class BufferLoad\|struct BufferLoad' {} \; | head -5

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for mbar usage in gemm_py.cc to find where it's read/accessed
rg -n 'mbar' src/op/gemm_py.cc | head -20

Repository: tile-ai/tilelang

Length of output: 114


🏁 Script executed:

# Search for GemmPyNode usage in other files to see how mbar is accessed
rg -n 'GemmPyNode\|mbar' src/ --include='*.cc' --include='*.h' | grep -v 'gemm_py' | head -20

Repository: tile-ai/tilelang

Length of output: 501


🏁 Script executed:

# Check the entire gemm_py.cc to see function context where mbar is used
wc -l src/op/gemm_py.cc

Repository: tile-ai/tilelang

Length of output: 81


🏁 Script executed:

cat -n src/op/gemm_py.cc

Repository: tile-ai/tilelang

Length of output: 15164


Resolve "optional" mismatch for mbar_ (type-contract inconsistency).

mbar_ is declared as non-optional tir::BufferLoad, but the comment marks it as optional and initialization is conditional (line 85: only assigned if args.size() > 16 and args[16] is a BufferLoadNode). This creates a type-contract mismatch: the field can remain uninitialized while the type signature suggests it's always present.

Suggestion: Make this std::optional<tir::BufferLoad> to match the semantic intent, or guarantee initialization and remove the optional comment. Verify Python-side code handles undefined mbar_ safely.

🤖 Prompt for AI Agents
In `@src/op/gemm_py.h` at line 32, mbar_ is declared as a non-optional
tir::BufferLoad but is only conditionally assigned (when args.size() > 16 and
args[16] is a BufferLoadNode), causing a type-contract mismatch; change the
field declaration from tir::BufferLoad mbar_ to std::optional<tir::BufferLoad>
mbar_, update the parser/initializer (where args is inspected) to emplace/assign
mbar_ only in the conditional branch, and adjust any uses of mbar_ (check
has_value() or use value_or) so code and the Python bindings safely handle the
absent case; alternatively, if you prefer non-optional, ensure mbar_ is
unconditionally initialized in the same constructor code path and remove the
"optional" comment.

Comment on lines 95 to 133
PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
Buffer buf = load->buffer;
int ndim = static_cast<int>(buf->shape.size());

// Compute offset using row-major layout (iterate in reverse)
PrimExpr offset = make_const(DataType::Int(32), 0);
PrimExpr stride = make_const(DataType::Int(32), 1);

for (int i = ndim - 1; i >= 0; --i) {
const PrimExpr &index = load->indices[i];
if (const auto *ramp = index.as<RampNode>()) {
// For Ramp, use the base
offset = offset + ramp->base * stride;
} else {
// For scalar index (IntImm or other PrimExpr)
offset = offset + index * stride;
}
stride = stride * buf->shape[i];
}

// Extent is 1 element for a single BufferLoad access
PrimExpr extent = make_const(DataType::Int(32), 1);

// Build access_ptr
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Avoid int32 offset/stride to prevent overflow on large buffers.

DataType::Int(32) can overflow for large shapes and diverges from MakeAccessPtrFromRegion. Use the buffer index dtype for offset/stride/extent.

💡 Suggested fix
-  // Compute offset using row-major layout (iterate in reverse)
-  PrimExpr offset = make_const(DataType::Int(32), 0);
-  PrimExpr stride = make_const(DataType::Int(32), 1);
+  // Compute offset using row-major layout (iterate in reverse)
+  DataType idx_dtype = buf->shape[0].dtype();
+  PrimExpr offset = make_const(idx_dtype, 0);
+  PrimExpr stride = make_const(idx_dtype, 1);
@@
-  // Extent is 1 element for a single BufferLoad access
-  PrimExpr extent = make_const(DataType::Int(32), 1);
+  // Extent is 1 element for a single BufferLoad access
+  PrimExpr extent = make_const(idx_dtype, 1);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
Buffer buf = load->buffer;
int ndim = static_cast<int>(buf->shape.size());
// Compute offset using row-major layout (iterate in reverse)
PrimExpr offset = make_const(DataType::Int(32), 0);
PrimExpr stride = make_const(DataType::Int(32), 1);
for (int i = ndim - 1; i >= 0; --i) {
const PrimExpr &index = load->indices[i];
if (const auto *ramp = index.as<RampNode>()) {
// For Ramp, use the base
offset = offset + ramp->base * stride;
} else {
// For scalar index (IntImm or other PrimExpr)
offset = offset + index * stride;
}
stride = stride * buf->shape[i];
}
// Extent is 1 element for a single BufferLoad access
PrimExpr extent = make_const(DataType::Int(32), 1);
// Build access_ptr
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
Buffer buf = load->buffer;
int ndim = static_cast<int>(buf->shape.size());
// Compute offset using row-major layout (iterate in reverse)
DataType idx_dtype = buf->shape[0].dtype();
PrimExpr offset = make_const(idx_dtype, 0);
PrimExpr stride = make_const(idx_dtype, 1);
for (int i = ndim - 1; i >= 0; --i) {
const PrimExpr &index = load->indices[i];
if (const auto *ramp = index.as<RampNode>()) {
// For Ramp, use the base
offset = offset + ramp->base * stride;
} else {
// For scalar index (IntImm or other PrimExpr)
offset = offset + index * stride;
}
stride = stride * buf->shape[i];
}
// Extent is 1 element for a single BufferLoad access
PrimExpr extent = make_const(idx_dtype, 1);
// Build access_ptr
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
🤖 Prompt for AI Agents
In `@src/op/utils.cc` around lines 95 - 122, The function
MakeAccessPtrFromBufferLoad uses hard-coded DataType::Int(32) for offset, stride
and extent which can overflow for large buffers; change all occurrences of
make_const(DataType::Int(32), ...) and the IntImm for rw_mask to use the
buffer's index dtype (buf->index_dtype) instead: initialize offset and stride
with make_const(buf->index_dtype, 0/1), compute offset/stride arithmetic with
that dtype, set extent using make_const(buf->index_dtype, 1), and construct the
rw_mask as IntImm(buf->index_dtype, rw_mask) when building acc_args; update
references inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args)
accordingly.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 23-47: Add an explicit precondition that M, N, and K are exact
multiples of block_M, block_N, and block_K respectively to prevent out-of-bounds
writes when the kernel (which computes k_iters via T.ceildiv and writes C_shared
-> C at by*block_M, bx*block_N) runs; insert an assertion near the top of the
function before the T.Kernel block (where k_iters, A/B/C and shared buffers are
set up) that checks M % block_M == 0, N % block_N == 0, and K % block_K == 0 and
fail early if not, mirroring the guards used in other GEMM kernels.

In `@tilelang/language/builtin.py`:
- Around line 812-814: The code references mbar_ptr in the tir.call_intrin call
but only sets mbar_ptr inside the isinstance(mbar, (tir.Buffer, BufferLoad))
branch, causing UnboundLocalError when a raw PrimExpr is passed; fix by ensuring
mbar_ptr is always defined: after the existing if-block set mbar_ptr = mbar for
the PrimExpr case (or otherwise convert the PrimExpr to the expected pointer
form) so that mbar_ptr is available before calling tir.call_intrin("void",
tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr); modify the logic around the
mbar/mbar_ptr handling in the function in tilelang/language/builtin.py to cover
both tir.Buffer/BufferLoad and PrimExpr inputs.

Comment on lines 23 to 38
k_iters = T.ceildiv(K, block_K)

A: T.Tensor[[M, K], in_dtype]
B: T.Tensor[[K, N], in_dtype]
C = T.empty((M, N), out_dtype)

with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
loaded = T.alloc_barrier([32] * num_stages)
consumed = T.alloc_barrier([1] * num_stages)
tmem_full = T.alloc_barrier([1])

tx = T.get_thread_binding()

T.use_swizzle(8)

if tx < 32: # warp 0: issue tma
for k in T.serial(k_iters):
T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :])
T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Find the file mentioned in the review
fd "gemm_tcgen5mma_ws.py" --type f

Repository: tile-ai/tilelang

Length of output: 100


🏁 Script executed:

# List the repository structure to understand the codebase layout
ls -la examples/ 2>/dev/null | head -20

Repository: tile-ai/tilelang

Length of output: 1244


🏁 Script executed:

# Search for other GEMM examples to understand the pattern
fd "gemm.*\.py" examples/ --type f | head -20

Repository: tile-ai/tilelang

Length of output: 1130


🏁 Script executed:

# Examine the file mentioned in the review
cat -n examples/gemm_sm100/gemm_tcgen5mma_ws.py

Repository: tile-ai/tilelang

Length of output: 4292


🏁 Script executed:

# Also search for other gemm_sm100 files to get context
ls -la examples/gemm_sm100/

Repository: tile-ai/tilelang

Length of output: 434


🏁 Script executed:

# Search for other GEMM examples to understand how they handle non-divisible dimensions
grep -r "T.copy.*\[.*\*.*block_" examples/gemm* --include="*.py" | head -20

Repository: tile-ai/tilelang

Length of output: 2122


🏁 Script executed:

# Look at other SM100 GEMM variants to see if they have guards
cat -n examples/gemm_sm100/gemm_tcgen5mma.py

Repository: tile-ai/tilelang

Length of output: 3381


🏁 Script executed:

# Search for any guard patterns (assertions, checks) in GEMM examples
rg "assert.*%|assert.*divisible|if.*%" examples/gemm* -A 2 -B 2 | head -40

Repository: tile-ai/tilelang

Length of output: 3576


Add explicit precondition to guard against non‑divisible M/N/K dimensions.
The kernel uses T.ceildiv for grid dimensions, and line 68 copies C_shared (shape [block_M, block_N]) to C[by * block_M, bx * block_N] without bounds checking. If M/N/K are not exact multiples of their block sizes, the last threadblocks will write out of bounds. The test case passes only because M=N=K=16384 are exact multiples of their block sizes. Add an assertion to enforce this precondition, as done in other GEMM kernels in the codebase.

🛡️ Suggested guard (adjust to TileLang-supported assertion APIs)
 def gemm(
     A,
     B,
@@
 ):
     M, N, K = T.const("M, N, K")
+    # Kernel requires exact tile coverage to avoid OOB access at line 68.
+    assert (
+        M % block_M == 0
+        and N % block_N == 0
+        and K % block_K == 0
+    ), "M/N/K must be divisible by block_M/block_N/block_K"
🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py` around lines 23 - 47, Add an
explicit precondition that M, N, and K are exact multiples of block_M, block_N,
and block_K respectively to prevent out-of-bounds writes when the kernel (which
computes k_iters via T.ceildiv and writes C_shared -> C at by*block_M,
bx*block_N) runs; insert an assertion near the top of the function before the
T.Kernel block (where k_iters, A/B/C and shared buffers are set up) that checks
M % block_M == 0, N % block_N == 0, and K % block_K == 0 and fail early if not,
mirroring the guards used in other GEMM kernels.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py`:
- Around line 31-33: Add explicit tiling precondition checks to prevent
out-of-bounds tile loads/stores: validate at runtime that M % block_M == 0 and N
% block_N == 0 to ensure full tiles for the slices used when loading/storing
(referencing m_blocks and n_blocks), and also assert that block_N %
store_block_N == 0 (to prevent OOB access into C_local when iterating
store_block_N across a block_N tile). Implement these as early
assertions/argument checks (or raise clear exceptions) before any tensor slicing
or the compute loop so callers supplying M, N, block_M, block_N, and
store_block_N cannot trigger OOB accesses.

Comment on lines +31 to +33
m_blocks = T.ceildiv(M, block_M)
n_blocks = T.ceildiv(N, block_N)
assert K % (2 * block_K) == 0 # for simplicity
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /tmp/repo && find . -name "gemm_tcgen5mma_ws_persistent.py" -type f

Repository: tile-ai/tilelang

Length of output: 119


🏁 Script executed:

head -150 examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py | cat -n

Repository: tile-ai/tilelang

Length of output: 7610


Add tiling precondition checks to prevent OOB loads/stores.

The kernel loads/stores full tiles without boundary masking. If M or N aren't multiples of block_M or block_N, the slices on lines 63 and 65 will read past array bounds. Similarly, on line 119, if store_block_N doesn't evenly divide block_N, the loop will access out-of-bounds memory in C_local. While the test case uses compatible parameters (8192 with block size 128/256, store_block_N 128), the kernel interface accepts these as runtime parameters, allowing unsafe calls.

💡 Suggested guardrails
-    assert K % (2 * block_K) == 0  # for simplicity
+    assert K % (2 * block_K) == 0, "K must be divisible by 2 * block_K"
+    assert M % block_M == 0 and N % block_N == 0, "M/N must be multiples of block_M/block_N"
+    assert store_block_N <= block_N and block_N % store_block_N == 0, "store_block_N must evenly divide block_N"
🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py` around lines 31 - 33,
Add explicit tiling precondition checks to prevent out-of-bounds tile
loads/stores: validate at runtime that M % block_M == 0 and N % block_N == 0 to
ensure full tiles for the slices used when loading/storing (referencing m_blocks
and n_blocks), and also assert that block_N % store_block_N == 0 (to prevent OOB
access into C_local when iterating store_block_N across a block_N tile).
Implement these as early assertions/argument checks (or raise clear exceptions)
before any tensor slicing or the compute loop so callers supplying M, N,
block_M, block_N, and store_block_N cannot trigger OOB accesses.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/gemm.cc (1)

90-91: ⚠️ Potential issue | 🟡 Minor

Potential out-of-bounds access on args[17] and args[18].

cCoords_ initialization accesses args[17] and args[18] unconditionally, but the preceding check only validates args.size() > 16. If the caller provides exactly 17 elements, this will cause an out-of-bounds access.

Consider adding a bounds check or documenting the minimum required args size:

Suggested fix
+  ICHECK(args.size() >= 19) << "Expected at least 19 arguments for Gemm, got " << args.size();
   node->cCoords_ = Array<PrimExpr>(
       {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
🧹 Nitpick comments (2)
src/op/gemm_py.cc (2)

83-86: Design note: Silent ignore is intentional for placeholder constants.

The Python side (gemm_op.py, lines 109-112) passes tir.const(0, dtype="int32") as a placeholder when mbar is None. The current behavior of silently ignoring non-BufferLoad values is intentional to handle this placeholder case correctly.

However, for debugging purposes, you might consider adding a debug log when a non-BufferLoad is encountered to help catch genuine call-site errors vs. expected placeholder values.

💡 Optional: Add debug logging
   if (args.size() > 16) {
     if (const auto *load = args[16].as<BufferLoadNode>()) {
       node->mbar_ = Downcast<BufferLoad>(args[16]);
+    } else {
+      DLOG(INFO) << "mbar arg is not BufferLoad (placeholder or error): "
+                 << args[16]->GetTypeKey();
     }
   }

88-89: Consider adding bounds check for cCoords_ access.

Lines 88-89 access args[17] and args[18] unconditionally after only checking args.size() > 16. While the Python call site currently always passes 19 arguments, adding a bounds check would make this more robust against future API changes.

💡 Suggested defensive check
+  ICHECK(args.size() >= 19) << "Expected at least 19 arguments for GemmPy, got " << args.size();
   node->cCoords_ = Array<PrimExpr>(
       {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/gemm_op.py (1)

99-113: ⚠️ Potential issue | 🔴 Critical

Convert mbar directly to tir.BufferLoad; to_buffer_region() returns BufferRegion which fails the C++ BufferLoadNode check.

When mbar is converted via to_buffer_region(mbar, access_type="rw") without extents, it returns a BufferRegion. The C++ side at line 84 of src/op/gemm.cc checks if (const auto *load = args[16].as<BufferLoadNode>()). Since BufferRegion is not a BufferLoadNode, this check fails and mbar_ is set to std::nullopt, disabling barrier synchronization. This later causes a crash when line 493 tries to call MakeAccessPtrFromBufferLoad(mbar_.value()).

🔧 Suggested fix
    if mbar is not None:
        assert isinstance(mbar, (tir.Buffer, tir.BufferLoad)), (
            f"mbar for tcgen5mma must be a tir.Buffer or tir.BufferLoad, but got {type(mbar)}"
        )
-        mbar = to_buffer_region(mbar, access_type="rw")
+        if isinstance(mbar, tir.Buffer):
+            mbar = tir.BufferLoad(mbar, [0])
🤖 Fix all issues with AI agents
In `@src/op/gemm.cc`:
- Around line 83-88: The code silently sets node->mbar_ to nullopt when args[16]
exists but is not a BufferLoadNode; instead, validate the type and fail fast:
when args.size() > 16 and args[16] is non-null, check that
args[16].as<BufferLoadNode>() is non-null and if it isn't, raise an explicit
error (e.g., TVM_PANIC/LOG(FATAL)/ICHECK/throw) describing the unexpected arg
type for mbar_; otherwise keep the existing Downcast<BufferLoad>(args[16])
assignment to node->mbar_.

Comment on lines 83 to 86
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
node->mbar_ = Downcast<BufferLoad>(args[16]);
} else {
node->mbar_ = std::nullopt;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Validate mbar arg type instead of silently dropping it.

If args[16] is non-null but not a BufferLoad, this currently becomes nullopt and can hide frontend regressions (e.g., still passing a region). Consider an explicit check to fail fast.

🔧 Suggested guard
   if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    } else {
-      node->mbar_ = std::nullopt;
-    }
+    const auto *load = args[16].as<BufferLoadNode>();
+    ICHECK(load || !args[16].defined())
+        << "mbar must be a BufferLoad or null when provided";
+    node->mbar_ =
+        load ? std::optional<BufferLoad>(Downcast<BufferLoad>(args[16]))
+             : std::nullopt;
   }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
node->mbar_ = Downcast<BufferLoad>(args[16]);
} else {
node->mbar_ = std::nullopt;
}
if (args.size() > 16) {
const auto *load = args[16].as<BufferLoadNode>();
ICHECK(load || !args[16].defined())
<< "mbar must be a BufferLoad or null when provided";
node->mbar_ =
load ? std::optional<BufferLoad>(Downcast<BufferLoad>(args[16]))
: std::nullopt;
}
🤖 Prompt for AI Agents
In `@src/op/gemm.cc` around lines 83 - 88, The code silently sets node->mbar_ to
nullopt when args[16] exists but is not a BufferLoadNode; instead, validate the
type and fail fast: when args.size() > 16 and args[16] is non-null, check that
args[16].as<BufferLoadNode>() is non-null and if it isn't, raise an explicit
error (e.g., TVM_PANIC/LOG(FATAL)/ICHECK/throw) describing the unexpected arg
type for mbar_; otherwise keep the existing Downcast<BufferLoad>(args[16])
assignment to node->mbar_.

@Rachmanino
Copy link
Collaborator Author

@LeiWang1999 fixed

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/gemm_op.py (1)

99-112: ⚠️ Potential issue | 🔴 Critical

mbar is incorrectly converted to BufferRegion but C++ expects BufferLoadNode.

Line 103 converts mbar to a tir.BufferRegion via to_buffer_region(mbar, access_type="rw"). However, line 112 passes this BufferRegion directly as mbar_arg to call_intrin. The C++ side (gemm_py.cc lines 84–86) checks args[16].as<BufferLoadNode>(), which will fail because BufferRegion is not a BufferLoadNode. This means mbar_ will remain uninitialized even when a valid barrier is provided.

For consistency with A_arg, B_arg, and C_arg (which use buffer_region_to_tile_region to wrap in tl.region calls), either:

  • Apply buffer_region_to_tile_region to mbar as well, or
  • Pass the original mbar without converting to BufferRegion on line 103, or
  • Extract the BufferLoad from the BufferRegion before passing it
🧹 Nitpick comments (1)
examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py (1)

53-66: Tile-ID to (bx, by) mapping is duplicated three times.

The tile_id → (bx, by) computation (lines 55–57, 70–72, 104–106) is identical across all three warp branches. Consider extracting it into a helper or computing it once before the branch to reduce duplication and the risk of divergence if the mapping logic changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants