[Enhancement] Enhance let binding handling in layout inference and warp specialized pass#1484
Conversation
…rence * Introduced a new static method `FullyReplicated` in the `Fragment` class to create fully replicated fragment layouts, ensuring all threads hold identical copies of the buffer. * Updated `CopyNode` to collect fragment layouts and mark them as fully replicated during layout inference. * Enhanced `ParallelOpNode` to expand let bindings for fragment buffer accesses, improving layout inference accuracy. * Added documentation for new methods and updated existing methods to support the new layout features.
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
Caution Review failedThe pull request is closed. WalkthroughAdds let-binding-aware fragment layout inference and a new Fragment factory. LetStmt variable mappings are tracked and propagated through collection, inference, and lowering so fragment buffers referenced via let-bound expressions are detected and can be created as fully replicated fragments. Changes
Sequence Diagram(s)sequenceDiagram
participant LetCollector as LetStmt Collector
participant BufferCollector as BufferUseDefCollector
participant LayoutInfer as Layout Inference
participant ParallelOp as ParallelOpNode
participant CopyOp as CopyNode
participant FragFactory as Fragment Factory
Note over LetCollector,BufferCollector: Phase 1 — Record let bindings & find fragment refs
LetCollector->>BufferCollector: record var -> expr mappings
BufferCollector->>BufferCollector: CollectFragmentBuffersFromExpr(expr) (follow lets)
Note over LayoutInfer: Phase 2 — Infer layouts with let mappings
LayoutInfer->>ParallelOp: InferLayout(..., let_var_to_expr)
ParallelOp->>ParallelOp: ExpandLetBindings(let_var_to_expr)
ParallelOp-->>LayoutInfer: updated indice_map / fragment uses
LayoutInfer->>CopyOp: InferLayout(..., let_var_to_expr)
CopyOp->>CopyOp: CollectFragmentLayouts(expr, let_var_to_expr, ...)
CopyOp->>FragFactory: request FullyReplicated(shape, thread_extent)
FragFactory-->>CopyOp: Layout for fragment
Note over LayoutInfer: Phase 3 — Propagate layouts into Lowering
LayoutInfer-->>Lowering: pass layout_map and let_var_to_expr via LowerArgs
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review detailsConfiguration used: defaults Review profile: CHILL Plan: Pro 📒 Files selected for processing (1)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
…amline output and improve performance.
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/op/parallel.cc (1)
185-211: Let-binding expansion logic is sound; consider avoiding repeated work / const_cast.The new
ExpandLetBindingscorrectly followslet_var_to_exprchains and discoverslocal.fragmentBufferLoads that would otherwise be invisible toParallelLoopNestVisitor, which is exactly what layout inference needs for let-bound fragments.Two small polish points you may want to consider:
InferLayoutisconstbut mutatesindice_map_viaconst_cast. Ifindice_map_is conceptually part of the analysis state, making itmutableand markingExpandLetBindingsconstwould avoid the const-cast and better reflect intent.InferLayoutmay be called multiple times per operator at differentInferLevels;ExpandLetBindingswill re-traverse alllet_var_to_exprexpressions each time even afterindice_map_has been fully populated. If this shows up in profiles, a simple boolean “let_expanded_” flag or caching of visited vars could skip redundant traversals.Given typical let-binding counts this is unlikely to be a bottleneck, so these are optional cleanups rather than blockers.
Also applies to: 246-249
src/transform/layout_inference.cc (1)
322-333: Consider downgrading verbose layout_map dumps to debug logging.Dumping the entire
layout_mapwithLOG(INFO)after bothFinishInferQueueandInferInFreeModeis very helpful while bringing up this feature, but in real models this can produce a large amount of log spam on the default INFO level.If you still need this for debugging, switching to
DLOG(INFO)(or gating behind a PassContext flag) would keep default logs cleaner while preserving the diagnostics when explicitly enabled.testing/python/language/test_tilelang_language_let_layout.py (1)
36-47: Tidy up unused kernel/test parameters and temporaries.A few minor cleanups will make this test clearer and keep lint quiet:
- Line 36:
bxfromwith T.Kernel(...) as (bx, by)is never used. Renaming it to_or_bxwould make the intent explicit and satisfy linters.- Lines 37–42:
B_sharedis allocated and cleared but never read or written afterwards. If it’s not part of the intended scenario, consider removing it (and theT.clear(B_shared)call); if it is, wiring it into the copy path would make the test’s purpose clearer.- Line 52:
ref_blocksparse_copytakesNbut never uses it. Either drop the parameter (and correspondingly adjust the call at lines 95–96) or add a simple sanity check involvingNto justify its presence.These are all non-functional nits, but addressing them will reduce noise from tools like Ruff.
Also applies to: 52-64
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
src/layout/layout.cc(1 hunks)src/layout/layout.h(1 hunks)src/op/copy.cc(4 hunks)src/op/copy.h(1 hunks)src/op/fill.cc(2 hunks)src/op/operator.h(2 hunks)src/op/parallel.cc(2 hunks)src/op/parallel.h(1 hunks)src/transform/layout_inference.cc(5 hunks)src/transform/layout_reducer.cc(1 hunks)src/transform/lower_tile_op.cc(1 hunks)src/transform/warp_specialized_rewriter.cc(3 hunks)testing/python/language/test_tilelang_language_let_layout.py(1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/warp_specialized_rewriter.ccsrc/transform/layout_inference.cc
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_let_layout.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_let_layout.py
🧬 Code graph analysis (10)
src/layout/layout.h (1)
src/layout/layout.cc (4)
Fragment(516-538)Fragment(540-550)FullyReplicated(552-556)FullyReplicated(552-553)
src/op/parallel.h (1)
src/op/parallel.cc (2)
ExpandLetBindings(185-211)ExpandLetBindings(185-186)
src/transform/warp_specialized_rewriter.cc (2)
src/transform/lower_tile_op.cc (2)
expr(387-399)expr(387-387)src/transform/cluster_planning.cc (2)
var(93-93)var(93-93)
src/op/parallel.cc (3)
src/transform/layout_inference.cc (6)
expr(548-573)expr(548-548)expr(575-585)expr(575-575)expr(781-794)expr(781-781)src/transform/lower_tile_op.cc (2)
expr(387-399)expr(387-387)src/transform/warp_specialized_rewriter.cc (6)
expr(65-74)expr(65-65)expr(77-89)expr(77-77)var(370-376)var(370-370)
src/layout/layout.cc (1)
tilelang/layout/fragment.py (1)
Fragment(13-205)
src/op/copy.cc (1)
src/layout/layout.cc (2)
FullyReplicated(552-556)FullyReplicated(552-553)
testing/python/language/test_tilelang_language_let_layout.py (3)
tilelang/language/allocate.py (1)
alloc_fragment(72-85)tilelang/language/copy_op.py (1)
copy(14-95)tilelang/language/loop.py (1)
Pipelined(58-95)
src/transform/layout_reducer.cc (1)
src/layout/layout.cc (2)
FullyReplicated(552-556)FullyReplicated(552-553)
src/op/copy.h (1)
src/op/copy.cc (2)
CollectFragmentLayouts(2063-2086)CollectFragmentLayouts(2063-2068)
src/transform/layout_inference.cc (1)
src/transform/warp_specialized_rewriter.cc (12)
op(38-43)op(38-38)op(91-94)op(91-91)op(96-106)op(96-96)expr(65-74)expr(65-65)expr(77-89)expr(77-77)var(370-376)var(370-370)
🪛 Ruff (0.14.8)
testing/python/language/test_tilelang_language_let_layout.py
36-36: Unpacked variable bx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
52-52: Unused function argument: N
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (11)
src/op/copy.h (1)
272-293: LGTM! Well-documented private helper method.The new
CollectFragmentLayoutsmethod is properly declared with comprehensive documentation that clearly explains its purpose, parameters, and behavior. The documentation follows Doxygen style and the method encapsulation is appropriate.src/op/parallel.h (1)
108-111: LGTM! Clear method declaration with helpful example.The
ExpandLetBindingsmethod is well-documented with a concrete example showing its use case. The private visibility and simple signature are appropriate for this helper method.src/op/fill.cc (1)
161-168: LGTM! Consistent updates to InferLayout calls.Both call sites correctly pass an empty map for
let_var_to_expr, which is appropriate since fill operations don't involve let-bound expressions. The explicitInferLevel::kFreespecification improves code clarity.Also applies to: 184-191
src/transform/lower_tile_op.cc (1)
641-650: LGTM! Proper propagation of let bindings to lowering.The conversion from
let_bindings_toMap<Var, PrimExpr>is straightforward and correctly passes the mapping to the tile operator'sLowermethod. This enables fragment buffer resolution through let bindings during lowering.src/op/operator.h (1)
42-44: LGTM! Consistent struct extensions for let binding support.Both
LowerArgsandLayoutInferArgsare properly extended with thelet_var_to_exprfield, maintaining consistency across the lowering and layout inference interfaces. The documentation clearly explains the purpose of this new field.Also applies to: 54-56
src/layout/layout.cc (1)
552-556: LGTM! Clean factory method for fully replicated fragments.The
FullyReplicatedstatic factory provides a clear and concise way to create fragments where all threads hold identical buffer copies. The implementation correctly uses an emptyforward_indexandReplicationPlaceholder()to represent full replication.src/layout/layout.h (1)
178-190: LGTM! Excellent API documentation for new factory method.The declaration of
FullyReplicatedincludes comprehensive documentation that clearly explains the concept of full replication, its purpose, and use cases. TheTVM_DLLexport macro is appropriately used for this public API.src/transform/layout_reducer.cc (1)
216-216: LGTM! Improved code clarity with factory method.The change to use
Fragment::FullyReplicatedis more expressive and maintainable than the previous constructor approach. It clearly conveys the intent that all threads hold identical copies forALLreplication type reducers.src/transform/warp_specialized_rewriter.cc (1)
51-54: Let-aware producer buffer discovery matches cluster_planning pattern and looks correct.The additions to
ProducerUsedBufferFinder(trackinglet_var_to_expr_,CollectBuffersFromExpr, and wiring it intoInsertBufferand TMA call handling) make this pass properly follow let-bound expressions when computingproducer_buffers_. The structure mirrors the existingCollectFragmentBuffersFromExprin other passes, recursion is bounded by LetStmt nesting, and the use ofstd::unordered_set<const BufferNode*>avoids duplication concerns.From a maintainability point of view this is in a good shape and aligns well with the rest of the pipeline.
Also applies to: 65-90, 91-94, 125-132, 135-137
src/transform/layout_inference.cc (1)
113-120: Let-binding-aware fragment collection is consistent and correctly wired through layout inference.Passing
let_var_to_expr_intoLayoutInferArgs, recording bindings inVisitStmt_(LetStmtNode), and usingCollectFragmentBuffersFromExprfromVisitExpr_(CallNode)givesBufferUseDefCollectorthe ability to seelocal.fragmentbuffers that are only reachable through let-bound expressions. This matches the pattern used in other passes and ensures fragment buffers contributing indices (e.g.a = block_mask_f[i]) are added touse_list_and hence participate in layout inference.The recursive
CollectFragmentBuffersFromExpris safe given TIR’s acyclic LetStmt structure, and the extra indirection vialet_var_to_expr_is only invoked when needed (on Var occurrences).No functional issues here; this closes the gap the tests are targeting.
Also applies to: 493-497, 772-777, 779-795, 856-858
src/op/copy.cc (1)
559-577: Bulk copy now correctly infers fully-replicated layouts for fragment index buffers.The new
CollectFragmentLayoutshelper plus theresult_mapwiring in the bulk path cleanly address fragment buffers used as indices in bulk load/store:
- Traversing
src_range/dst_rangeexpressions (including vialet_var_to_expr) ensures that fragment-scoped buffers likeblock_mask_fare discovered even when only referenced through let-bound indices.- For such
local.fragmentbuffers with no existing layout, assigningFragment::FullyReplicated(buffer->shape, thread_extent)->BindThreadRange(thread_bounds)matches the intended “every thread has the same contents” semantics.- Returning
result_mapthat includes both any new fragment layouts and a linearshared_tensorlayout keeps this inference localized to the bulk path and avoids perturbing other operators.This integrates well with the updated
LayoutInferArgsand should give Bulk/TMA paths the same fragment-awareness as the rest of layout inference.Also applies to: 582-586, 2063-2086
This pull request introduces significant improvements to the handling of "fragment" buffer layouts, especially regarding their propagation through let bindings, and enhances layout inference and lowering logic for tile operators. The changes ensure that fragment buffers accessed indirectly via let variables are correctly recognized and assigned fully replicated layouts. Additionally, the PR refactors and generalizes the codebase to support these improvements across copy, fill, and parallel operations.
Fragment Buffer Layout Inference and Propagation:
Fragment::FullyReplicatedto create layouts where all threads hold identical copies of a buffer, streamlining the handling of index/mask fragments. [1] [2] [3]CopyNode::CollectFragmentLayouts, ensuring all relevant buffers are marked as fully replicated during layout inference. [1] [2] [3]Let Binding Tracking and Propagation:
LowerArgs,LayoutInferArgs) to include alet_var_to_exprmap, enabling operators to resolve fragment buffer accesses through let bindings. [1] [2] [3]ParallelOpNodeand related layout inference logic to expand let bindings and find fragment buffer accesses, ensuring correct propagation of layout information. [1] [2] [3]Generalization Across Operators:
let_var_to_expr, ensuring consistent fragment buffer handling in all relevant operators. [1] [2] [3]Analysis and Debugging Improvements:
BufferUseDefCollectorto track let bindings, recursively collect fragment buffers through expressions, and provide debug logging of inferred layouts. [1] [2] [3] [4] [5]Supporting Changes:
ProducerUsedBufferFinderto clear and use let binding information for buffer collection, ensuring correct buffer usage tracking in warp-specialized rewriting. [1] [2]These changes collectively improve the robustness and correctness of fragment buffer layout inference, particularly in complex scenarios involving indirect accesses via let bindings.
Summary by CodeRabbit
Release Notes
New Features
Tests
✏️ Tip: You can customize this high-level summary in your review settings.