[Refactor] Relocate layout transformation of ptx_stmatrix #1689

LeiWang1999 merged 13 commits into tile-ai:main from ptx_stmatrix

Conversation

👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Walkthrough

Added full-range checks and early fallback in LDSM copy lowering; removed shared-buffer remapping in copy lowering; implemented layout-aware rewrites for `tvm_access_ptr`/`address_of`/`ptx_*` calls.

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
participant IR as IR Call
participant Lower as LowerTileOpPass
participant BMap as buffer_map_/var_remap_
participant Layout as Layout API
participant Rewriter as Call Rewriter
IR->>Lower: visit `tvm_access_ptr`/`address_of`/`ptx_*` call
Lower->>BMap: resolve handle-key or data-key → original buffer/param
BMap-->>Lower: original buffer (+ layout?)
alt layout exists
Lower->>Layout: linear offset → multi-dim indices
Layout-->>Lower: Forward(indices) → remapped indices
Lower->>Layout: compute new total_offset & new_buffer.data
Lower->>Rewriter: build rewritten call (new data ptr, new offset)
Rewriter-->>Lower: rewritten call
Lower-->>IR: replace original call with rewritten call
else no layout
Lower-->>IR: leave call unchanged
end
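To make the offset math in the diagram concrete, a worked example with hypothetical numbers (not values from the PR): for a 64×64 shared buffer, a linear elem_offset of 130 decomposes to multi-dimensional indices (2, 2), since 130 = 2·64 + 2; `Forward` then maps those indices through the swizzled layout, and the result is re-linearized against the new buffer's shape to produce the rewritten offset.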
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/transform/lower_tile_op.cc`:
- Around line 346-379: The new shape may include a prepended replicate dimension
created by makeBufferWithLayout, so before indexing forward_indices in
lower_tile_op (where forward_indices = layout->Forward(multi_dim_indices) and
new_shape is used to compute new_offset) compute the replicate index as
elem_offset divided by the layout extent (use the product of
layout->OutputShape() or a provided layout_extent) (simplify with
analyzer_->Simplify if needed) and prepend that replicate index to
forward_indices (or assert that forward_indices.size() == new_shape.size());
this ensures forward_indices aligns with new_shape and prevents out-of-range
access when the output rank increases due to shared-layout replication.
src/transform/lower_tile_op.cc
Outdated
// Get original and new buffer shapes
Array<PrimExpr> old_shape = original_buffer->shape;
Array<PrimExpr> new_shape = new_buffer->shape;
// Convert linear offset to multi-dimensional indices
Array<PrimExpr> multi_dim_indices;
PrimExpr remaining_offset = elem_offset;
for (int i = static_cast<int>(old_shape.size()) - 1; i >= 0; --i) {
  multi_dim_indices.insert(
      multi_dim_indices.begin(),
      analyzer_->Simplify(floormod(remaining_offset, old_shape[i])));
  remaining_offset = floordiv(remaining_offset, old_shape[i]);
}
// Apply layout transformation
auto forward_indices = layout->Forward(multi_dim_indices);

PrimExpr new_offset = 0;
PrimExpr stride_offset = 1;
for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
  new_offset += forward_indices[i] * stride_offset;
  stride_offset *= new_shape[i];
}
new_offset = analyzer_->Simplify(new_offset);
Array<PrimExpr> new_indices;
for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
  new_indices.insert(new_indices.begin(),
                     floormod(new_offset, new_shape[i]));
  new_offset = floordiv(new_offset, new_shape[i]);
}
PrimExpr total_offset = 0;
PrimExpr new_stride_offset = 1;
for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
  total_offset += new_indices[i] * new_stride_offset;
  new_stride_offset *= new_shape[i];
}
Handle shared-layout replication when output rank increases.
makeBufferWithLayout can prepend a replicate dimension for shared buffers. When that happens, new_shape.size() becomes layout->OutputShape().size() + 1, but the loop computing new_offset indexes forward_indices with the larger rank, causing out-of-range access or incorrect offsets. Please prepend the replicate index (derived from elem_offset / layout_extent) or assert rank consistency before using forward_indices.
🛠️ Proposed fix
- // Apply layout transformation
- auto forward_indices = layout->Forward(multi_dim_indices);
-
- PrimExpr new_offset = 0;
- PrimExpr stride_offset = 1;
- for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
- new_offset += forward_indices[i] * stride_offset;
- stride_offset *= new_shape[i];
- }
+ // Apply layout transformation
+ auto forward_indices = layout->Forward(multi_dim_indices);
+ // If makeBufferWithLayout prepended a replicate dim, prepend it here too.
+ if (new_shape.size() == forward_indices.size() + 1) {
+ int64_t layout_extent = 1;
+ for (const auto& s : layout->OutputShape()) {
+ const auto* imm = s.as<IntImmNode>();
+ ICHECK(imm) << "Layout output shape must be constant integer";
+ layout_extent *= imm->value;
+ }
+ PrimExpr replicate_idx =
+ floordiv(elem_offset, IntImm(elem_offset->dtype, layout_extent));
+ forward_indices.insert(forward_indices.begin(), replicate_idx);
+ } else {
+ ICHECK_EQ(new_shape.size(), forward_indices.size())
+ << "Layout output rank mismatch for remapped buffer";
+ }
+
+ PrimExpr new_offset = 0;
+ PrimExpr stride_offset = 1;
+ for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
+ new_offset += forward_indices[i] * stride_offset;
+ stride_offset *= new_shape[i];
+ }
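A quick arithmetic check of the replicate-index derivation above, with hypothetical numbers: if the layout's output shape is 32×64, `layout_extent` is 2048, so `elem_offset` = 5000 yields a replicate index of ⌊5000/2048⌋ = 2, leaving 5000 − 2·2048 = 904 as the intra-tile offset covered by the Forward-mapped indices.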
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/lower_tile_op.cc (1)
443-449: Same replicate dimension issue as tvm_access_ptr branch. This loop also iterates over `new_shape.size()` but indexes into `forward_indices`, which will be one element smaller when a replicate dimension is prepended. Apply the same fix as recommended for the tvm_access_ptr branch.
♻️ Duplicate comments (1)
src/transform/lower_tile_op.cc (1)
354-361: Handle shared-layout replication when output rank increases. The same issue from the previous review remains: `makeBufferWithLayout` can prepend a replicate dimension for shared buffers. When that happens, `new_shape.size()` becomes `forward_indices.size() + 1`, but the loop indexes `forward_indices` with the larger rank, causing out-of-range access.
🧹 Nitpick comments (2)
src/transform/lower_tile_op.cc (2)
363-374: Consider simplifying redundant index decomposition. After computing `new_offset` (line 362), the code decomposes it into `new_indices` and then recomputes `total_offset`. This round-trip is mathematically equivalent to using `new_offset` directly when indices are in range. If this is intentional (e.g., for normalization), a brief comment would clarify the intent.

♻️ Suggested simplification

  new_offset = analyzer_->Simplify(new_offset);
- Array<PrimExpr> new_indices;
- for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
-   new_indices.insert(new_indices.begin(),
-                      floormod(new_offset, new_shape[i]));
-   new_offset = floordiv(new_offset, new_shape[i]);
- }
- PrimExpr total_offset = 0;
- PrimExpr new_stride_offset = 1;
- for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
-   total_offset += new_indices[i] * new_stride_offset;
-   new_stride_offset *= new_shape[i];
- }
  layout_remap_.Set(new_buffer, layout);
  // Build new tvm_access_ptr call with new buffer and offset
  Array<PrimExpr> new_args = access_ptr_call->args;
  new_args.Set(1, new_buffer->data); // Replace data var
- new_args.Set(2, total_offset);     // Replace offset
+ new_args.Set(2, new_offset);       // Replace offset
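Why the round-trip is a no-op for in-range offsets: a tiny standalone C++ illustration (hypothetical shape and offset, not code from this PR) showing that mixed-radix decomposition followed by recomposition reproduces the original linear offset.

#include <cassert>
#include <vector>

int main() {
  std::vector<int> shape = {4, 8};  // assumed 2-D tile, row-major
  int offset = 13;                  // any value in [0, 4*8)
  // Decompose: 13 -> (1, 5), peeling the innermost dimension first
  std::vector<int> idx(shape.size());
  int rem = offset;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    idx[i] = rem % shape[i];
    rem /= shape[i];
  }
  // Recompose: (1, 5) -> 1*8 + 5 = 13
  int total = 0, stride = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    total += idx[i] * stride;
    stride *= shape[i];
  }
  assert(total == offset);  // the round-trip is the identity
  return 0;
}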
428-431: Remove or use dead code. `buffer_row_size` is computed but explicitly discarded with `(void)`. If the `CheckAndGetBufferRowSize` call is needed for its assertion side effect, add a comment. Otherwise, remove the unused computation.

♻️ Suggested fix (remove if not needed)

  auto buffer_map_iter = buffer_map_.find(Downcast<Var>(remap_key->data));
-
- int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
- (void)buffer_row_size;
+ // Validate buffer has at least 2 dimensions (row-major assumption)
+ CheckAndGetBufferRowSize(buffer_map_iter->second);
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/lower_tile_op.cc (1)
435-441: Address same rank mismatch issue and missing cross-tile offset handling. This `address_of` path has the same rank mismatch issue as the `tvm_access_ptr` path when shared buffers have a replicate dimension. Additionally, unlike the `tvm_access_ptr` path (line 364), this path doesn't add `remaining_offset * stride_offset` to handle accesses beyond one tile.

🛠️ Suggested fix

  auto forward_indices = layout.value()->Forward(multi_dim_indices);
+ // Handle prepended replicate dimension for shared buffers
+ if (new_shape.size() == forward_indices.size() + 1) {
+   forward_indices.insert(forward_indices.begin(), remaining_offset);
+   remaining_offset = IntImm(remaining_offset->dtype, 0);
+ }
  PrimExpr new_offset = 0;
  PrimExpr stride_offset = 1;
  for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
    new_offset += forward_indices[i] * stride_offset;
    stride_offset *= new_shape[i];
  }
+ // Add remaining offset for accesses beyond one tile
+ new_offset += remaining_offset * stride_offset;
  new_offset = analyzer_->Simplify(new_offset);
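Note the asymmetry with the `tvm_access_ptr` fix: here `elem_offset` has already been divided through every dimension of `old_shape`, so the leftover `remaining_offset` is exactly the cross-tile quotient, which is why it can serve directly as the replicate index (and is zeroed afterward to avoid double-counting).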
♻️ Duplicate comments (1)
src/transform/lower_tile_op.cc (1)
354-365: Handle shared-layout replication when output rank increases. When `makeBufferWithLayout` prepends a replicate dimension for shared buffers (when `replicate_extent > 1`), `new_shape.size()` becomes `layout->OutputShape().size() + 1`. The loop at line 358 iterates over `new_shape.size()` and accesses `forward_indices[i]`, but `forward_indices` only has `layout->OutputShape().size()` elements, causing an out-of-range access.

The `remaining_offset` handling at line 364 partially addresses replication, but the loop bounds still don't account for the prepended dimension.
🧹 Nitpick comments (1)
src/transform/lower_tile_op.cc (1)
420-423: Remove dead code: `buffer_row_size` is computed but unused. `CheckAndGetBufferRowSize` is called and the result is immediately discarded with `(void)`. This is likely leftover from a previous implementation or incomplete logic. Either use this value or remove the computation.

- int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
- (void)buffer_row_size;
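If the assertion side effect should be kept while making the intent explicit, C++17's `[[maybe_unused]]` attribute is an alternative to the `(void)` cast (a general C++ note, not an existing project convention): `[[maybe_unused]] int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);`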
…el source prints; update LowerLDSMCopy to handle non-full range cases and clean up unused code in copy.cc; add debug prints in OptimizeForTarget for shared memory allocation.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/op/copy.cc`:
- Around line 776-789: The full-range check uses the wrong region vector:
local_region is set to src_range when is_ldmatrix is true but it should
correspond to local_tensor (dst for LDSM, src for STSM). Change the assignment
of local_region to use dst_range when is_ldmatrix and src_range otherwise
(local_region = is_ldmatrix ? dst_range : src_range), then leave the existing
analyzer->CanProveEqual loop and fallback to LowerNormalCopy(T, analyzer) as-is.
In `@tilelang/engine/phase.py`:
- Around line 275-279: Replace the unconditional prints around the
MergeSharedMemoryAllocations pass with gated debug output: detect the existing
debug/pass-config flag (the same flag used by the C++ pass) or use a module
logger and only dump the IRModule (mod) when that flag/logger is enabled; apply
this change to the block calling
tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)
so the "Before/After MergeSharedMemoryAllocations" messages and IR dumps are
emitted only when debugging is enabled.
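To make the first prompt concrete, here is a minimal sketch of the corrected full-range check, assuming the names referenced in the review (`is_ldmatrix`, `src_range`/`dst_range`, `local_tensor`, `LowerNormalCopy`); the real signatures in `copy.cc` may differ:

// The "local" tensor is the destination of ldmatrix but the source of
// stmatrix, so the region to check must be chosen accordingly.
const Array<Range>& local_region = is_ldmatrix ? dst_range : src_range;
bool full_range = true;
for (size_t i = 0; i < local_region.size(); ++i) {
  // Every dimension must start at 0 and cover the whole buffer extent.
  if (!analyzer->CanProveEqual(local_region[i]->min, 0) ||
      !analyzer->CanProveEqual(local_region[i]->extent,
                               local_tensor->shape[i])) {
    full_range = false;
    break;
  }
}
if (!full_range) {
  // Only a sub-tile is accessed: fall back to the generic copy lowering.
  return LowerNormalCopy(T, analyzer);
}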
tilelang/engine/phase.py
Outdated
| print("Before MergeSharedMemoryAllocations") | ||
| print(mod) | ||
| mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) | ||
| print("After MergeSharedMemoryAllocations") | ||
| print(mod) |
Gate IR dumps behind a debug flag.
Unconditional `print(mod)` will spam logs and can be expensive for large IRModules. Please guard this behind a pass-config (ideally the same flag used by the C++ pass) or a debug logger.
💡 Suggested change (guarded debug output)
- print("Before MergeSharedMemoryAllocations")
- print(mod)
+ debug_merge = bool(pass_ctx.config.get("tl.debug_merge_shared_memory_allocations", False))
+ if debug_merge:
+ print("Before MergeSharedMemoryAllocations")
+ print(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod)
- print("After MergeSharedMemoryAllocations")
- print(mod)
+ if debug_merge:
+ print("After MergeSharedMemoryAllocations")
+ print(mod)🤖 Prompt for AI Agents
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/lower_tile_op.cc (1)
466-472: Same rank mismatch issue in address_of path. The `address_of` handling has the same potential out-of-range access when `makeBufferWithLayout` prepends a replicate dimension. Apply the same fix as suggested for the `tvm_access_ptr` path.

🛠️ Proposed fix

  auto forward_indices = layout.value()->Forward(multi_dim_indices);
+ // Handle replicate dimension if present
+ if (new_shape.size() == forward_indices.size() + 1) {
+   int64_t layout_extent = 1;
+   for (const auto& s : layout.value()->OutputShape()) {
+     const auto* imm = s.as<IntImmNode>();
+     ICHECK(imm) << "Layout output shape must be constant integer";
+     layout_extent *= imm->value;
+   }
+   PrimExpr replicate_idx =
+       floordiv(smem_offset, IntImm(smem_offset->dtype, layout_extent));
+   forward_indices.insert(forward_indices.begin(), replicate_idx);
+ } else {
+   ICHECK_EQ(new_shape.size(), forward_indices.size())
+       << "Layout output rank mismatch for remapped buffer";
+ }
+
  PrimExpr new_offset = 0;
  PrimExpr stride_offset = 1;
  for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
    new_offset += forward_indices[i] * stride_offset;
    stride_offset *= new_shape[i];
  }
♻️ Duplicate comments (1)
src/transform/lower_tile_op.cc (1)
369-376: Handle shared-layout replication when output rank increases.
`makeBufferWithLayout` can prepend a replicate dimension for shared buffers (see lines 74-76). When that happens, `new_shape.size()` becomes `forward_indices.size() + 1`, but the loop indexes `forward_indices` with the larger rank, causing out-of-range access or incorrect offsets.

🛠️ Proposed fix

  // Apply layout transformation
  auto forward_indices = layout->Forward(multi_dim_indices);
+ // If makeBufferWithLayout prepended a replicate dim, prepend it here too.
+ if (new_shape.size() == forward_indices.size() + 1) {
+   int64_t layout_extent = 1;
+   for (const auto& s : layout->OutputShape()) {
+     const auto* imm = s.as<IntImmNode>();
+     ICHECK(imm) << "Layout output shape must be constant integer";
+     layout_extent *= imm->value;
+   }
+   PrimExpr replicate_idx =
+       floordiv(elem_offset, IntImm(elem_offset->dtype, layout_extent));
+   forward_indices.insert(forward_indices.begin(), replicate_idx);
+ } else {
+   ICHECK_EQ(new_shape.size(), forward_indices.size())
+       << "Layout output rank mismatch for remapped buffer";
+ }
+
  PrimExpr new_offset = 0;
  PrimExpr stride_offset = 1;
  for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
    new_offset += forward_indices[i] * stride_offset;
    stride_offset *= new_shape[i];
  }
🧹 Nitpick comments (2)
src/transform/lower_tile_op.cc (2)
593-600: Clarify the intent with explicit reassignment after mutation. After calling `CopyOnWrite()` and modifying `call_node->args`, the code reads back from `call->args[5]`. While this works because `CopyOnWrite()` modifies the object in place when uniquely referenced, the pattern is subtle and could be clearer. Consider explicitly re-assigning for readability:

♻️ Suggested improvement for clarity

  if (!load_expr.same_as(access_ptr_call->args[0])) {
    auto call_node = call.CopyOnWrite();
-   call_node->args.Set(
-       5, Call(access_ptr_call->dtype, access_ptr_call->op, {load_expr},
-               access_ptr_call->annotations, access_ptr_call->span));
-   access_ptr_call = Downcast<Call>(call->args[5]);
-   access_ptr = call->args[5];
+   PrimExpr new_access_call =
+       Call(access_ptr_call->dtype, access_ptr_call->op, {load_expr},
+            access_ptr_call->annotations, access_ptr_call->span);
+   call_node->args.Set(5, new_access_call);
+   access_ptr = new_access_call;
+   access_ptr_call = Downcast<Call>(new_access_call);
  }
561-563: Consider clarifying the is_ptx_ recursion guard intent. The early return when `is_ptx_` is true acts as a recursion guard during child visitation. This prevents double-processing but assumes PTX intrinsics won't contain nested PTX calls that need independent transformation. While this is likely safe in practice, a brief comment explaining this invariant would help maintainability.

📝 Suggested comment

  if (is_ptx_) {
+   // Recursion guard: when visiting children of a PTX intrinsic, skip
+   // re-processing any nested PTX calls (not expected in practice).
    return Downcast<Call>(op);
  }
as title.
Summary by CodeRabbit
New Features
Refactor
Bug Fixes
Chores
✏️ Tip: You can customize this high-level summary in your review settings.