[Enhancement] Enhance let binding handling in layout inference and warp specialized pass #1484

Merged
LeiWang1999 merged 3 commits into tile-ai:main from LeiWang1999:let_1221
Dec 20, 2025

Conversation

@LeiWang1999 (Member) commented Dec 20, 2025

This pull request introduces significant improvements to the handling of "fragment" buffer layouts, especially regarding their propagation through let bindings, and enhances layout inference and lowering logic for tile operators. The changes ensure that fragment buffers accessed indirectly via let variables are correctly recognized and assigned fully replicated layouts. Additionally, the PR refactors and generalizes the codebase to support these improvements across copy, fill, and parallel operations.

Fragment Buffer Layout Inference and Propagation:

  • Added a new static method Fragment::FullyReplicated to create layouts where all threads hold identical copies of a buffer, streamlining the handling of index/mask fragments.
  • Implemented recursive collection of fragment buffers from expressions (including those accessed through let bindings) in CopyNode::CollectFragmentLayouts, ensuring all relevant buffers are marked as fully replicated during layout inference.

Let Binding Tracking and Propagation:

  • Extended layout inference and lowering argument structs (LowerArgs, LayoutInferArgs) to include a let_var_to_expr map, enabling operators to resolve fragment buffer accesses through let bindings.
  • Modified ParallelOpNode and related layout inference logic to expand let bindings and find fragment buffer accesses, ensuring correct propagation of layout information.

Generalization Across Operators:

  • Updated fill and copy operator lowering and layout inference to pass along let_var_to_expr, ensuring consistent fragment buffer handling in all relevant operators.

Analysis and Debugging Improvements:

  • Enhanced the BufferUseDefCollector to track let bindings, recursively collect fragment buffers through expressions, and provide debug logging of inferred layouts.

Supporting Changes:

  • Updated the ProducerUsedBufferFinder to clear and use let binding information for buffer collection, ensuring correct buffer usage tracking in warp-specialized rewriting.

These changes collectively improve the robustness and correctness of fragment buffer layout inference, particularly in complex scenarios involving indirect accesses via let bindings.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added a static factory to create fully replicated fragments more easily.
    • Layout inference now recognizes fragment buffers referenced via let-bindings, improving replication handling across copy and parallel paths.
    • Bulk operation layout inference enhanced to track and populate fragment buffer layouts for replicated buffers.
  • Tests

    • Added tests validating fragment layout inference and copy paths with let-bound fragment accesses.

…rence

* Introduced a new static method `FullyReplicated` in the `Fragment` class to create fully replicated fragment layouts, ensuring all threads hold identical copies of the buffer.
* Updated `CopyNode` to collect fragment layouts and mark them as fully replicated during layout inference.
* Enhanced `ParallelOpNode` to expand let bindings for fragment buffer accesses, improving layout inference accuracy.
* Added documentation for new methods and updated existing methods to support the new layout features.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai bot (Contributor) commented Dec 20, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

Adds let-binding-aware fragment layout inference and a new Fragment factory. LetStmt variable mappings are tracked and propagated through collection, inference, and lowering so fragment buffers referenced via let-bound expressions are detected and can be created as fully replicated fragments.

Changes

| Cohort / File(s) | Change Summary |
|---|---|
| Fragment Factory Method (src/layout/layout.h, src/layout/layout.cc) | Added Fragment::FullyReplicated(Array<PrimExpr> shape, PrimExpr thread_extent) static factory to create fully-replicated fragments. |
| Operator Argument Structures (src/op/operator.h) | Added Map<Var, PrimExpr> let_var_to_expr to LowerArgs and LayoutInferArgs to propagate LetStmt bindings into lowering and layout inference. |
| Copy Operation Layout Inference (src/op/copy.h, src/op/copy.cc) | Added CopyNode::CollectFragmentLayouts(...) to collect fragment-local buffers (resolving let bindings) and produce replicated layouts; InferLayout now returns populated layout maps for bulk paths and uses par_op_->InferLayout result consistently. |
| Parallel Operation Let-Binding Handling (src/op/parallel.h, src/op/parallel.cc) | Added ParallelOpNode::ExpandLetBindings(const Map<Var,PrimExpr>&) to traverse let-bound expressions and populate internal indices; invoked from InferLayout when let mappings are provided. |
| Fill Operation Layout Inference (src/op/fill.cc) | Calls to InferLayout for local.fragment/shared scopes now pass an empty layout map and explicit InferLevel::kFree to the parallel op's InferLayout. |
| Layout Inference Let-Binding Support (src/transform/layout_inference.cc) | BufferUseDefCollector now records let_var_to_expr_, visits LetStmtNode to record bindings, and adds CollectFragmentBuffersFromExpr to resolve fragment buffers reachable via let-bound expressions; passes let mappings into LayoutInferArgs. |
| Layout Reducer Fragment Construction (src/transform/layout_reducer.cc) | Replaced prior direct constructor usage with Fragment::FullyReplicated(buffer->shape, thread_extent) when creating ALL-replicated fragments. |
| Lower Tile Operation Let-Binding Propagation (src/transform/lower_tile_op.cc) | Converted internal let_bindings_ to Map<Var,PrimExpr> let_var_to_expr and passed it through LowerArgs into tile_op->Lower. |
| Warp Specialization Let-Binding Buffer Tracking (src/transform/warp_specialized_rewriter.cc) | ProducerUsedBufferFinder gains let_var_to_expr_, VisitStmt_(LetStmtNode), and CollectBuffersFromExpr so producer buffer discovery follows let-bindings. |
| Fragment Layout Let-Binding Tests (testing/python/language/test_tilelang_language_let_layout.py) | New tests exercising fragment layout inference with LetStmt-bound indices and validating TMA and CP.ASYNC copy paths (blocksparse_copy kernel, reference, and test harnesses). |

Sequence Diagram(s)

sequenceDiagram
    participant LetCollector as LetStmt Collector
    participant BufferCollector as BufferUseDefCollector
    participant LayoutInfer as Layout Inference
    participant ParallelOp as ParallelOpNode
    participant CopyOp as CopyNode
    participant FragFactory as Fragment Factory

    Note over LetCollector,BufferCollector: Phase 1 — Record let bindings & find fragment refs
    LetCollector->>BufferCollector: record var -> expr mappings
    BufferCollector->>BufferCollector: CollectFragmentBuffersFromExpr(expr) (follow lets)

    Note over LayoutInfer: Phase 2 — Infer layouts with let mappings
    LayoutInfer->>ParallelOp: InferLayout(..., let_var_to_expr)
    ParallelOp->>ParallelOp: ExpandLetBindings(let_var_to_expr)
    ParallelOp-->>LayoutInfer: updated indice_map / fragment uses

    LayoutInfer->>CopyOp: InferLayout(..., let_var_to_expr)
    CopyOp->>CopyOp: CollectFragmentLayouts(expr, let_var_to_expr, ...)
    CopyOp->>FragFactory: request FullyReplicated(shape, thread_extent)
    FragFactory-->>CopyOp: Layout for fragment

    Note over LayoutInfer: Phase 3 — Propagate layouts into Lowering
    LayoutInfer-->>Lowering: pass layout_map and let_var_to_expr via LowerArgs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Inspect recursive let-binding traversal (CollectFragmentBuffersFromExpr / CollectBuffersFromExpr) for completeness and cycles.
  • Verify let_var_to_expr correctly propagates to all InferLayout and Lower calls (fill, copy, parallel, tile lowering).
  • Review const-correctness and side-effects where ExpandLetBindings is invoked (cast-away-const pattern).
  • Confirm tests exercise both TMA and CP.ASYNC paths and that new factory is used consistently.

Poem

🐰 In tunnels of Let and index light,

I hop through bindings, soft and bright.
Each fragment found, a mirrored hue,
Fully replicated — stitched anew.
I nibble bugs and map the way,
Hopping layouts into play. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 44.44% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main focus of the PR: enhancing let binding handling in layout inference and the warp specialized pass, which aligns with the core changes across multiple files. |

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1301d68 and b743030.

📒 Files selected for processing (1)
  • src/transform/layout_inference.cc (4 hunks)

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
src/op/parallel.cc (1)

185-211: Let-binding expansion logic is sound; consider avoiding repeated work / const_cast.

The new ExpandLetBindings correctly follows let_var_to_expr chains and discovers local.fragment BufferLoads that would otherwise be invisible to ParallelLoopNestVisitor, which is exactly what layout inference needs for let-bound fragments.

Two small polish points you may want to consider:

  • InferLayout is const but mutates indice_map_ via const_cast. If indice_map_ is conceptually part of the analysis state, making it mutable and marking ExpandLetBindings const would avoid the const-cast and better reflect intent.
  • InferLayout may be called multiple times per operator at different InferLevels; ExpandLetBindings will re-traverse all let_var_to_expr expressions each time even after indice_map_ has been fully populated. If this shows up in profiles, a simple boolean “let_expanded_” flag or caching of visited vars could skip redundant traversals.

Given typical let-binding counts this is unlikely to be a bottleneck, so these are optional cleanups rather than blockers.

Also applies to: 246-249
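
The caching idea in the second bullet can be sketched in a few lines. This is a toy Python model, not the C++ code under review: `ParallelOpSketch`, `let_fragment_loads`, the `let_expanded` flag, and the `traversals` counter are hypothetical names chosen for illustration, and the let map is simplified to a dictionary from a let variable name to the fragment loads found in its bound expression.

```python
class ParallelOpSketch:
    """Toy stand-in for an operator whose layout inference runs several times."""

    def __init__(self):
        self.indice_map = {}       # fragment buffer name -> access indices
        self.let_expanded = False  # hypothetical flag, as suggested in the review
        self.traversals = 0        # counts full traversals, for illustration only

    def expand_let_bindings(self, let_fragment_loads):
        # let_fragment_loads: let variable name -> list of (buffer name, indices)
        if self.let_expanded:
            return  # bindings were already expanded by an earlier call
        self.traversals += 1
        for loads in let_fragment_loads.values():
            for buffer_name, indices in loads:
                # Keep the first recorded access for each fragment buffer.
                self.indice_map.setdefault(buffer_name, indices)
        self.let_expanded = True

    def infer_layout(self, let_fragment_loads, level):
        # `level` is a placeholder for the inference level; the sketch ignores it.
        self.expand_let_bindings(let_fragment_loads)
        return dict(self.indice_map)


op = ParallelOpSketch()
lets = {"a": [("block_mask_f", ("i",))]}
for level in range(3):  # three inference rounds at placeholder levels
    op.infer_layout(lets, level)
print(op.traversals)  # prints 1
```

The flag is only safe if the let map does not change between calls for the same operator; that is an assumption of this sketch, not something the review states.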

src/transform/layout_inference.cc (1)

322-333: Consider downgrading verbose layout_map dumps to debug logging.

Dumping the entire layout_map with LOG(INFO) after both FinishInferQueue and InferInFreeMode is very helpful while bringing up this feature, but in real models this can produce a large amount of log spam on the default INFO level.

If you still need this for debugging, switching to DLOG(INFO) (or gating behind a PassContext flag) would keep default logs cleaner while preserving the diagnostics when explicitly enabled.
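
As a language-neutral illustration of the gating idea, here is a minimal Python analog using the standard `logging` module. It is not the C++ logging used in the pass; `dump_layout_map` and the logger name are hypothetical, and the layout strings are placeholders.

```python
import logging

logger = logging.getLogger("layout_inference_sketch")


def dump_layout_map(layout_map, stage):
    """Emit one debug line per buffer; return how many lines were emitted."""
    if not logger.isEnabledFor(logging.DEBUG):
        return 0  # skip the whole dump at the default verbosity
    for name in sorted(layout_map):
        logger.debug("[%s] %s -> %s", stage, name, layout_map[name])
    return len(layout_map)


layout_map = {"block_mask_f": "fully replicated", "A_shared": "linear"}

logger.setLevel(logging.INFO)
quiet = dump_layout_map(layout_map, "FinishInferQueue")

logger.setLevel(logging.DEBUG)
verbose = dump_layout_map(layout_map, "InferInFreeMode")
```

Checking the level before iterating also avoids formatting a large map when nothing would be printed.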

testing/python/language/test_tilelang_language_let_layout.py (1)

36-47: Tidy up unused kernel/test parameters and temporaries.

A few minor cleanups will make this test clearer and keep lint quiet:

  • Line 36: bx from with T.Kernel(...) as (bx, by) is never used. Renaming it to _ or _bx would make the intent explicit and satisfy linters.
  • Lines 37–42: B_shared is allocated and cleared but never read or written afterwards. If it’s not part of the intended scenario, consider removing it (and the T.clear(B_shared) call); if it is, wiring it into the copy path would make the test’s purpose clearer.
  • Line 52: ref_blocksparse_copy takes N but never uses it. Either drop the parameter (and correspondingly adjust the call at lines 95–96) or add a simple sanity check involving N to justify its presence.

These are all non-functional nits, but addressing them will reduce noise from tools like Ruff.

Also applies to: 52-64

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 168aec7 and 1301d68.

📒 Files selected for processing (13)
  • src/layout/layout.cc (1 hunks)
  • src/layout/layout.h (1 hunks)
  • src/op/copy.cc (4 hunks)
  • src/op/copy.h (1 hunks)
  • src/op/fill.cc (2 hunks)
  • src/op/operator.h (2 hunks)
  • src/op/parallel.cc (2 hunks)
  • src/op/parallel.h (1 hunks)
  • src/transform/layout_inference.cc (5 hunks)
  • src/transform/layout_reducer.cc (1 hunks)
  • src/transform/lower_tile_op.cc (1 hunks)
  • src/transform/warp_specialized_rewriter.cc (3 hunks)
  • testing/python/language/test_tilelang_language_let_layout.py (1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/transform/warp_specialized_rewriter.cc
  • src/transform/layout_inference.cc
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/language/test_tilelang_language_let_layout.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_let_layout.py
🧬 Code graph analysis (10)
src/layout/layout.h (1)
src/layout/layout.cc (4)
  • Fragment (516-538)
  • Fragment (540-550)
  • FullyReplicated (552-556)
  • FullyReplicated (552-553)
src/op/parallel.h (1)
src/op/parallel.cc (2)
  • ExpandLetBindings (185-211)
  • ExpandLetBindings (185-186)
src/transform/warp_specialized_rewriter.cc (2)
src/transform/lower_tile_op.cc (2)
  • expr (387-399)
  • expr (387-387)
src/transform/cluster_planning.cc (2)
  • var (93-93)
  • var (93-93)
src/op/parallel.cc (3)
src/transform/layout_inference.cc (6)
  • expr (548-573)
  • expr (548-548)
  • expr (575-585)
  • expr (575-575)
  • expr (781-794)
  • expr (781-781)
src/transform/lower_tile_op.cc (2)
  • expr (387-399)
  • expr (387-387)
src/transform/warp_specialized_rewriter.cc (6)
  • expr (65-74)
  • expr (65-65)
  • expr (77-89)
  • expr (77-77)
  • var (370-376)
  • var (370-370)
src/layout/layout.cc (1)
tilelang/layout/fragment.py (1)
  • Fragment (13-205)
src/op/copy.cc (1)
src/layout/layout.cc (2)
  • FullyReplicated (552-556)
  • FullyReplicated (552-553)
testing/python/language/test_tilelang_language_let_layout.py (3)
tilelang/language/allocate.py (1)
  • alloc_fragment (72-85)
tilelang/language/copy_op.py (1)
  • copy (14-95)
tilelang/language/loop.py (1)
  • Pipelined (58-95)
src/transform/layout_reducer.cc (1)
src/layout/layout.cc (2)
  • FullyReplicated (552-556)
  • FullyReplicated (552-553)
src/op/copy.h (1)
src/op/copy.cc (2)
  • CollectFragmentLayouts (2063-2086)
  • CollectFragmentLayouts (2063-2068)
src/transform/layout_inference.cc (1)
src/transform/warp_specialized_rewriter.cc (12)
  • op (38-43)
  • op (38-38)
  • op (91-94)
  • op (91-91)
  • op (96-106)
  • op (96-96)
  • expr (65-74)
  • expr (65-65)
  • expr (77-89)
  • expr (77-77)
  • var (370-376)
  • var (370-370)
🪛 Ruff (0.14.8)
testing/python/language/test_tilelang_language_let_layout.py

36-36: Unpacked variable bx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


52-52: Unused function argument: N

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (11)
src/op/copy.h (1)

272-293: LGTM! Well-documented private helper method.

The new CollectFragmentLayouts method is properly declared with comprehensive documentation that clearly explains its purpose, parameters, and behavior. The documentation follows Doxygen style and the method encapsulation is appropriate.

src/op/parallel.h (1)

108-111: LGTM! Clear method declaration with helpful example.

The ExpandLetBindings method is well-documented with a concrete example showing its use case. The private visibility and simple signature are appropriate for this helper method.

src/op/fill.cc (1)

161-168: LGTM! Consistent updates to InferLayout calls.

Both call sites correctly pass an empty map for let_var_to_expr, which is appropriate since fill operations don't involve let-bound expressions. The explicit InferLevel::kFree specification improves code clarity.

Also applies to: 184-191

src/transform/lower_tile_op.cc (1)

641-650: LGTM! Proper propagation of let bindings to lowering.

The conversion from let_bindings_ to Map<Var, PrimExpr> is straightforward and correctly passes the mapping to the tile operator's Lower method. This enables fragment buffer resolution through let bindings during lowering.

src/op/operator.h (1)

42-44: LGTM! Consistent struct extensions for let binding support.

Both LowerArgs and LayoutInferArgs are properly extended with the let_var_to_expr field, maintaining consistency across the lowering and layout inference interfaces. The documentation clearly explains the purpose of this new field.

Also applies to: 54-56

src/layout/layout.cc (1)

552-556: LGTM! Clean factory method for fully replicated fragments.

The FullyReplicated static factory provides a clear and concise way to create fragments where all threads hold identical buffer copies. The implementation correctly uses an empty forward_index and ReplicationPlaceholder() to represent full replication.
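
To make "all threads hold identical copies" concrete, the following toy Python model maps every logical index of a buffer to the set of threads that own a copy. It does not reproduce the forward_index or ReplicationPlaceholder() machinery; `fully_replicated_owners` is a hypothetical helper that only illustrates the ownership property.

```python
from itertools import product


def fully_replicated_owners(shape, thread_extent):
    """Map each logical index to the threads that hold a copy of that element."""
    all_threads = tuple(range(thread_extent))
    return {
        index: all_threads
        for index in product(*(range(extent) for extent in shape))
    }


owners = fully_replicated_owners((2, 2), thread_extent=4)
# Every element is owned by every thread in the block.
print(owners[(1, 0)])  # prints (0, 1, 2, 3)
```

Under this model no cross-thread communication is needed to read an index or mask value, which is why such a layout suits small index/mask fragments.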

src/layout/layout.h (1)

178-190: LGTM! Excellent API documentation for new factory method.

The declaration of FullyReplicated includes comprehensive documentation that clearly explains the concept of full replication, its purpose, and use cases. The TVM_DLL export macro is appropriately used for this public API.

src/transform/layout_reducer.cc (1)

216-216: LGTM! Improved code clarity with factory method.

The change to use Fragment::FullyReplicated is more expressive and maintainable than the previous constructor approach. It clearly conveys the intent that all threads hold identical copies for ALL replication type reducers.

src/transform/warp_specialized_rewriter.cc (1)

51-54: Let-aware producer buffer discovery matches cluster_planning pattern and looks correct.

The additions to ProducerUsedBufferFinder (tracking let_var_to_expr_, CollectBuffersFromExpr, and wiring it into InsertBuffer and TMA call handling) make this pass properly follow let-bound expressions when computing producer_buffers_. The structure mirrors the existing CollectFragmentBuffersFromExpr in other passes, recursion is bounded by LetStmt nesting, and the use of std::unordered_set<const BufferNode*> avoids duplication concerns.

From a maintainability point of view this is in good shape and aligns well with the rest of the pipeline.

Also applies to: 65-90, 91-94, 125-132, 135-137

src/transform/layout_inference.cc (1)

113-120: Let-binding-aware fragment collection is consistent and correctly wired through layout inference.

Passing let_var_to_expr_ into LayoutInferArgs, recording bindings in VisitStmt_(LetStmtNode), and using CollectFragmentBuffersFromExpr from VisitExpr_(CallNode) gives BufferUseDefCollector the ability to see local.fragment buffers that are only reachable through let-bound expressions. This matches the pattern used in other passes and ensures fragment buffers contributing indices (e.g. a = block_mask_f[i]) are added to use_list_ and hence participate in layout inference.

The recursive CollectFragmentBuffersFromExpr is safe given TIR’s acyclic LetStmt structure, and the extra indirection via let_var_to_expr_ is only invoked when needed (on Var occurrences).

No functional issues here; this closes the gap the tests are targeting.

Also applies to: 493-497, 772-777, 779-795, 856-858
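
The traversal described above can be sketched with a toy expression tree. This is a minimal Python model under stated assumptions, not the TVM/TileLang implementation: the node classes (`Var`, `Const`, `Buffer`, `BufferLoad`, `Mul`) and `collect_fragment_buffers` are hypothetical stand-ins, and the constant 128 is an arbitrary tile size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Buffer:
    name: str
    scope: str


@dataclass(frozen=True)
class BufferLoad:
    buffer: Buffer
    indices: tuple


@dataclass(frozen=True)
class Mul:
    lhs: object
    rhs: object


def collect_fragment_buffers(expr, let_var_to_expr, found=None):
    """Return names of local.fragment buffers reachable from expr, following lets."""
    if found is None:
        found = set()
    if isinstance(expr, Var):
        bound = let_var_to_expr.get(expr)
        if bound is not None:  # only let-bound variables are expanded
            collect_fragment_buffers(bound, let_var_to_expr, found)
    elif isinstance(expr, BufferLoad):
        if expr.buffer.scope == "local.fragment":
            found.add(expr.buffer.name)
        for index in expr.indices:
            collect_fragment_buffers(index, let_var_to_expr, found)
    elif isinstance(expr, Mul):
        collect_fragment_buffers(expr.lhs, let_var_to_expr, found)
        collect_fragment_buffers(expr.rhs, let_var_to_expr, found)
    return found


block_mask_f = Buffer("block_mask_f", "local.fragment")
A = Buffer("A", "global")
i, a = Var("i"), Var("a")
lets = {a: BufferLoad(block_mask_f, (i,))}  # models: a = block_mask_f[i]
index = Mul(a, Const(128))                  # models: A[a * 128]
print(collect_fragment_buffers(BufferLoad(A, (index,)), lets))
```

The sketch relies on the same property the comment points out: let bindings are acyclic, so the recursion terminates. A cyclic map would recurse without bound here.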

src/op/copy.cc (1)

559-577: Bulk copy now correctly infers fully-replicated layouts for fragment index buffers.

The new CollectFragmentLayouts helper plus the result_map wiring in the bulk path cleanly address fragment buffers used as indices in bulk load/store:

  • Traversing src_range/dst_range expressions (including via let_var_to_expr) ensures that fragment-scoped buffers like block_mask_f are discovered even when only referenced through let-bound indices.
  • For such local.fragment buffers with no existing layout, assigning Fragment::FullyReplicated(buffer->shape, thread_extent)->BindThreadRange(thread_bounds) matches the intended “every thread has the same contents” semantics.
  • Returning result_map that includes both any new fragment layouts and a linear shared_tensor layout keeps this inference localized to the bulk path and avoids perturbing other operators.

This integrates well with the updated LayoutInferArgs and should give Bulk/TMA paths the same fragment-awareness as the rest of layout inference.

Also applies to: 582-586, 2063-2086
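
The "no existing layout" rule from the second bullet can be illustrated with a small Python sketch. It is a simplified model rather than the C++ helper: `assign_replicated_layouts` is a hypothetical function, layouts are represented as plain tuples, and thread-range binding is omitted.

```python
def assign_replicated_layouts(fragment_buffers, layout_map, thread_extent):
    """Give a replicated layout to each fragment buffer that has no layout yet.

    fragment_buffers: buffer name -> shape, for buffers found in index expressions
    layout_map:       buffer name -> layout already chosen by earlier inference
    """
    result = {}
    for name, shape in fragment_buffers.items():
        if name in layout_map:
            continue  # never overwrite a layout inferred elsewhere
        result[name] = ("fully_replicated", tuple(shape), thread_extent)
    return result


found = {"block_mask_f": (4,), "acc_f": (128, 128)}
existing = {"acc_f": ("some_mma_layout",)}
new_layouts = assign_replicated_layouts(found, existing, thread_extent=128)
print(new_layouts)
```

Only `block_mask_f` receives a new entry; `acc_f` keeps the layout it already had, which mirrors the intent of keeping the inference local to buffers that nothing else has constrained.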
