[Refactor] Update TVM subproject and refactor BlockNode handling in warp_specialized_rewriter.cc #812
Note: Other AI code review bot(s) detected
CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Annotates warp-specialized functions with a new AttrStmt using
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Rewriter as WarpSpecializedRewriter
participant Func as PrimFunc
participant Attr as AttrStmt(kCustomWarpSpecialization)
Note over Rewriter,Func: Warp specialization annotation
Rewriter->>Func: Analyze function for warp specialization
alt warp-specialized
Rewriter->>Attr: Create AttrStmt(key="kCustomWarpSpecialization", value=1)
Attr->>Func: Wrap function body with AttrStmt
Rewriter-->>Func: Return annotated function
else not warp-specialized
Rewriter->>Func: Return Substitute(... ) (unchanged path)
end
sequenceDiagram
autonumber
participant Injector as SetMaxNRegInjector
participant Collector as SetMaxNRegCollector
participant Func as PrimFunc
Note over Injector,Collector: Register-allocation collection & injection
Injector->>Collector: Collect(Func)
Collector->>Func: Visit body (including AttrStmt nodes)
alt AttrStmt(kCustomWarpSpecialization) detected
Func-->>Collector: Attr present -> Collector sets warp_specialized_ = true
Collector-->>Injector: Return empty nreg_ array
Injector->>Injector: Skip injection (nreg_ empty)
else No annotation
Collector-->>Injector: Return collected nreg_ values
alt nreg_ non-empty
Injector->>Func: Inject set_maxnreg hints
else nreg_ empty
Injector->>Injector: Skip injection
end
end
Estimated code review effort: 4 (Complex), ~45 minutes
Summary of Changes
Hello @chengyupku, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on updating the TVM subproject to its latest commit and enhancing the "warp_specialized_rewriter.cc" file. The refactoring in the C++ file streamlines the processing of "BlockNode" and "BlockRealizeNode" by adopting a unified "FilterByRole" mechanism, thereby improving the robustness and readability of the code.
Highlights
- TVM Subproject Update: The TVM subproject has been updated to commit "87b845fa0e14c2029bbf5799fbbbb9d490db4f20".
- Code Refactoring: The handling of "BlockNode" and "BlockRealizeNode" in "src/transform/warp_specialized_rewriter.cc" has been refactored to consistently use the "FilterByRole" function, improving code clarity and maintainability by removing previously unreachable "ICHECK(0)" statements.
Code Review
This pull request updates the TVM subproject and refactors warp_specialized_rewriter.cc. The refactoring to handle BlockNode and BlockRealizeNode is a good step towards making the code more robust. However, I have a significant concern about the implementation. The current approach of overriding the VisitStmt_ methods with a direct call to FilterByRole breaks the recursive nature of StmtMutator. This could lead to statements inside these blocks being silently ignored, causing correctness issues. My review comment provides more details on this potential bug and suggests a path for correction.
Overriding VisitStmt_ for BlockNode and BlockRealizeNode to call FilterByRole(op) directly prevents the default recursive behavior of StmtMutator. These nodes are containers, and a StmtMutator should typically recurse into their body (e.g., by calling this->VisitStmt(op->body)).
By not performing this recursion, any statements within these blocks will be skipped by WSCodeEmitter, potentially leading to incorrect code generation. The same potential issue exists for the LetStmtNode visitor in this class.
The previous ICHECK(0) may have been a safeguard against this exact scenario, indicating that a simple filtering approach is insufficient. Please refactor this to ensure that the bodies of these container nodes are visited recursively, for example by manually implementing the recursion before applying filtering logic.
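The concern above can be illustrated with a toy model. The sketch below is not TVM code: `Stmt`, `Role`, `FilterShallow`, and `FilterRecursive` are hypothetical stand-ins invented for illustration, and whether the real `FilterByRole` already recurses for mixed-role nodes is not shown in this thread. It only demonstrates the general point that a visitor which decides on a container without descending into it leaves nested statements unfiltered.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical miniature IR for illustration only; these are not TVM classes.
enum class Role { kProducer, kConsumer, kBoth };

struct Stmt {
  std::string name;
  Role role;
  bool is_container;
  std::vector<std::shared_ptr<Stmt>> body;  // children, used only by containers
};
using StmtPtr = std::shared_ptr<Stmt>;

StmtPtr Leaf(std::string name, Role role) {
  return std::make_shared<Stmt>(Stmt{std::move(name), role, false, {}});
}

StmtPtr Block(std::string name, Role role, std::vector<StmtPtr> body) {
  return std::make_shared<Stmt>(Stmt{std::move(name), role, true, std::move(body)});
}

// Variant A: decide on the node alone and never look inside a container.
// A null result stands for an elided statement.
StmtPtr FilterShallow(const StmtPtr& s, Role want) {
  assert(s != nullptr);
  if (s->role == want || s->role == Role::kBoth) return s;
  return nullptr;
}

// Variant B: single-role nodes are kept or elided as a whole, but a
// mixed-role container is rebuilt from its filtered children.
StmtPtr FilterRecursive(const StmtPtr& s, Role want) {
  assert(s != nullptr);
  if (s->role != Role::kBoth) return s->role == want ? s : nullptr;
  if (!s->is_container) return s;
  std::vector<StmtPtr> kept;
  for (const auto& child : s->body) {
    if (StmtPtr out = FilterRecursive(child, want)) kept.push_back(out);
  }
  return Block(s->name, s->role, std::move(kept));
}

// Helper: gather the names of all leaf statements that survive filtering.
void CollectLeaves(const StmtPtr& s, std::vector<std::string>* out) {
  if (!s) return;
  if (!s->is_container) {
    out->push_back(s->name);
    return;
  }
  for (const auto& child : s->body) CollectLeaves(child, out);
}

std::vector<std::string> LeafNames(const StmtPtr& s) {
  std::vector<std::string> out;
  CollectLeaves(s, &out);
  return out;
}
```

For a mixed-role block holding one producer leaf and one consumer leaf, the shallow variant emitting for the consumer keeps both leaves, while the recursive variant keeps only the consumer leaf.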
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/transform/warp_specialized_rewriter.cc (1)
892-893: Optional: always traverse container nodes (Block/BlockRealize) even when roles match.
Currently, when the role matches the emitter, FilterByRole returns the node unchanged and skips traversal. For container nodes, you may still want to descend to enable nested rewrites (e.g., inner For with tl_pipeline_* annotations) even if the entire block is single‑role.
Apply this tweak if you want guaranteed traversal for containers while preserving role-based elision:
- Stmt VisitStmt_(const BlockNode *op) final { return FilterByRole(op); }
+ Stmt VisitStmt_(const BlockNode *op) final {
+   Stmt out = FilterByRole(op);
+   if (out.same_as(GetRef<Stmt>(op)) && marker_.GetRole(op) != Role::kBoth && !mbarrier_only_) {
+     // Container: still traverse to allow nested transformations.
+     return StmtMutator::VisitStmt_(op);
+   }
+   return out;
+ }
- Stmt VisitStmt_(const BlockRealizeNode *op) final { return FilterByRole(op); }
+ Stmt VisitStmt_(const BlockRealizeNode *op) final {
+   Stmt out = FilterByRole(op);
+   if (out.same_as(GetRef<Stmt>(op)) && marker_.GetRole(op) != Role::kBoth && !mbarrier_only_) {
+     // Container: still traverse to allow nested transformations.
+     return StmtMutator::VisitStmt_(op);
+   }
+   return out;
+ }
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between ae9b706 and 43fc6e72f3ed6313d09e9b1c7b66675cdf1ffa6f.
📒 Files selected for processing (2)
- 3rdparty/tvm (1 hunks)
- src/transform/warp_specialized_rewriter.cc (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
PR: tile-ai/tilelang#794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/warp_specialized_rewriter.cc
🔇 Additional comments (2)
3rdparty/tvm (1)
1-1: TVM submodule bump to 87b845f: small TIR refactor; verify downstream compatibility, CI, and licensing
Compare eddefbd65acb7b1ea51dd18068b4049754c4fa7a -> 87b845fa0e14c2029bbf5799fbbbb9d490db4f20 (1 commit: "Refactor BlockReadWriteDetector analysis on BlockRealizeNode"). Changed file: src/tir/analysis/block_access_region_detector.cc.
- Confirm warp_specialized_rewriter.cc still builds and FilterByRole() behavior is unchanged; run unit/e2e tests covering Block and BlockRealize paths.
- Check for any local TVM patches that may have been dropped; reapply or document removals.
- Ensure CI performs deterministic submodule updates (git submodule sync + git submodule update --init --depth …) and caches builds to avoid flakiness.
- Verify license/NOTICE remains correct (TVM = Apache‑2.0) and no new third‑party code requires additional notices.
src/transform/warp_specialized_rewriter.cc (1)
892-893: Adopt FilterByRole for Block/BlockRealize: good cleanup and consistent with role marking.
This removes the previous “unreachable” stance and lets nested blocks flow through the same producer/consumer filtering as other nodes. Given WarpSpecializedRoleMarker assigns roles to both BlockNode and BlockRealizeNode, this is coherent.
Please confirm via tests that eliding a Block/BlockRealize with Evaluate(0) in non-selected roles does not violate any downstream TIR invariants (e.g., contexts expecting a BlockRealize as an immediate child). A small unit test covering nested blocks with mixed roles would be ideal.
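The kind of check requested here might look like the following sketch. It uses a hypothetical toy IR rather than TVM's classes, and the invariant it tests (`ChildrenAreBlockRealize`) is an assumed example of a downstream expectation, not one confirmed anywhere in this thread. It shows how replacing a wrong-role block with a no-op can break a consumer that expects every child to be a BlockRealize.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical miniature IR for illustration only; these are not TVM classes.
enum class Kind { kBlockRealize, kEvaluate, kSeq };
enum class Role { kProducer, kConsumer };

struct Stmt {
  Kind kind;
  Role role;
  std::vector<std::shared_ptr<Stmt>> children;
};
using StmtPtr = std::shared_ptr<Stmt>;

StmtPtr Make(Kind kind, Role role, std::vector<StmtPtr> children = {}) {
  return std::make_shared<Stmt>(Stmt{kind, role, std::move(children)});
}

// Role-based elision: a statement of the other role becomes a no-op
// Evaluate node (standing in for Evaluate(0)).
StmtPtr Elide(const StmtPtr& s, Role emitting) {
  assert(s != nullptr);
  if (s->role == emitting) return s;
  return Make(Kind::kEvaluate, emitting);
}

// Apply the elision to every child of a sequence node.
StmtPtr ElideChildren(const StmtPtr& seq, Role emitting) {
  assert(seq != nullptr && seq->kind == Kind::kSeq);
  std::vector<StmtPtr> out;
  for (const auto& child : seq->children) out.push_back(Elide(child, emitting));
  return Make(Kind::kSeq, emitting, std::move(out));
}

// Assumed downstream invariant: every child of the sequence is a BlockRealize.
bool ChildrenAreBlockRealize(const StmtPtr& seq) {
  assert(seq != nullptr && seq->kind == Kind::kSeq);
  for (const auto& child : seq->children) {
    if (child->kind != Kind::kBlockRealize) return false;
  }
  return true;
}
```

In this toy model the invariant holds before elision and fails afterwards for a sequence with mixed roles, which is the situation a real unit test would need to rule out or accept explicitly.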
…warp group register allocation
- Added a new attribute `kCustomWarpSpecialization` to support custom warp specialization in the TileLang framework.
- Updated the `Collect` method in `SetMaxNRegCollector` to handle cases where warp specialization is detected, returning an empty array accordingly.
- Enhanced the `SetMaxNRegInjector` to skip processing when no registers are needed, improving efficiency.
- Modified the `WarpSpecialized` pass to include the new attribute in the function body when warp specialization is enabled, ensuring proper handling in transformations.
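The collect-then-skip flow described in this commit message can be sketched as follows. This is a simplified model under stated assumptions: `Node`, `MakeNode`, `Collector`, and `ShouldInject` are hypothetical names, the tree is a stand-in for a PrimFunc body, and the register values used with it are arbitrary. Only the attribute key string and the "return an empty array when the marker is present" behavior come from the PR text.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a statement tree; not a TVM or TileLang type.
struct Node {
  std::string attr_key;  // non-empty when the node models an attribute wrapper
  int nreg_hint = -1;    // >= 0 when the node models a register hint
  std::vector<std::shared_ptr<Node>> children;
};
using NodePtr = std::shared_ptr<Node>;

// Key string as named in the PR; the constant itself is redeclared here.
constexpr const char* kCustomWarpSpecialization = "kCustomWarpSpecialization";

NodePtr MakeNode(std::string attr_key, int nreg_hint, std::vector<NodePtr> children) {
  auto node = std::make_shared<Node>();
  node->attr_key = std::move(attr_key);
  node->nreg_hint = nreg_hint;
  node->children = std::move(children);
  return node;
}

// Walks the whole tree, gathering register hints and noting whether the
// warp-specialization marker appears anywhere.
class Collector {
 public:
  static std::vector<int> Collect(const NodePtr& root) {
    Collector collector;
    collector.Visit(root);
    // Marker present: report no hints so the caller skips injection.
    if (collector.warp_specialized_) return {};
    return collector.nreg_;
  }

 private:
  void Visit(const NodePtr& node) {
    assert(node != nullptr);
    if (node->attr_key == kCustomWarpSpecialization) warp_specialized_ = true;
    if (node->nreg_hint >= 0) nreg_.push_back(node->nreg_hint);
    for (const auto& child : node->children) Visit(child);
  }

  bool warp_specialized_ = false;
  std::vector<int> nreg_;
};

// The injector side reduces to one question: is there anything to inject?
bool ShouldInject(const NodePtr& root) { return !Collector::Collect(root).empty(); }
```

Using an empty result as the sentinel keeps the injector's early return to a single size check, at the cost of conflating "no hints found" with "hints suppressed by the marker".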
43fc6e7 to ee8aca1
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/op/builtin.h (1)
28-28: Name/key consistency for warp specialization attribute
You now have multiple warp‑specialization keys in the codebase (e.g., string literal "warp_specialize" in warp_specialized_rewriter.h vs this new kCustomWarpSpecialization). Consider centralizing on tl::attr constants for both to avoid drift and make usages grep‑able. If the intent is a function‑level marker, a brief comment here would help clarify semantics.
Option A (add alias for existing literal to de‑duplicate string usage across files):
  namespace attr {
  static constexpr const char *kPaddingMap = "padding_map";
  static constexpr const char *kWarpSpecializationScope = "kWarpSpecializationScope";
+ // Function-level marker set by WarpSpecialized pass to indicate warp specialization is active.
+ static constexpr const char *kWarpSpecialize = "warp_specialize";
  static constexpr const char *kCustomWarpSpecialization = "kCustomWarpSpecialization";
  }
src/transform/annotate_warp_group_reg_alloc.cc (2)
49-55: Also check value == 1 and keep backward compatibility with "warp_specialize"
Right now any presence of the key flips the flag. Guard on IntImm==1, and (optionally) accept the legacy "warp_specialize" key to be robust to mixed modules.
Apply:
- void VisitStmt_(const AttrStmtNode *op) final {
-   if (op->attr_key == attr::kCustomWarpSpecialization) {
-     warp_specialized_ = true;
-   }
-   StmtExprVisitor::VisitStmt_(op);
- }
+ void VisitStmt_(const AttrStmtNode *op) final {
+   if (op->attr_key == attr::kCustomWarpSpecialization ||
+       op->attr_key == "warp_specialize") {
+     if (const auto* imm = op->value.as<IntImmNode>()) {
+       warp_specialized_ = warp_specialized_ || (imm->value == 1);
+     }
+   }
+   StmtExprVisitor::VisitStmt_(op);
+ }
67-69: Defensive check on nreg_ size before indexing later
Collect guarantees size 0 or 2, but an explicit assertion prevents accidental OOB if invariants change.
Apply:
  T.nreg_ = SetMaxNRegCollector::Collect(f);
  if (T.nreg_.size() == 0) {
    return f;
  }
+ ICHECK_EQ(T.nreg_.size(), 2) << "Expected 2 register hints or empty; got " << T.nreg_.size();
📜 Review details
📥 Commits
Reviewing files that changed from the base of the PR and between 43fc6e72f3ed6313d09e9b1c7b66675cdf1ffa6f and ee8aca1.
📒 Files selected for processing (3)
- src/op/builtin.h (1 hunks)
- src/transform/annotate_warp_group_reg_alloc.cc (2 hunks)
- src/transform/warp_specialized_rewriter.cc (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/transform/warp_specialized_rewriter.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/annotate_warp_group_reg_alloc.cc (2)
src/transform/warp_specialized_rewriter.cc (18)
op(38-43), op(38-38), op(73-83), op(73-73), op(85-93), op(85-85), op(95-100), op(95-95), op(102-110), op(102-102), op(134-146), op(134-134), op(148-177), op(148-148), op(179-189), op(179-179), f(1121-1140), f(1121-1122)
src/transform/warp_specialized_rewriter.h (1)
VisitStmt_(77-90)
🔇 Additional comments (1)
src/transform/annotate_warp_group_reg_alloc.cc (1)
20-22: Early exit sentinel for warp‑specialized functions: pass order confirmed
WarpSpecialized injects attr::kCustomWarpSpecialization; AnnotateWarpGroupRegAlloc detects that attr and returns early; the engine pipeline invokes WarpSpecialized before AnnotateWarpGroupRegAlloc.
Locations:
- src/transform/warp_specialized_rewriter.cc (AttrStmt(..., attr::kCustomWarpSpecialization) at ~1288)
- src/transform/annotate_warp_group_reg_alloc.cc (VisitStmt_ checks attr::kCustomWarpSpecialization at ~48–52)
- tilelang/engine/phase.py (WarpSpecialized() then AnnotateWarpGroupRegAlloc() at ~128–130)
…arp_specialized_rewriter.cc (tile-ai#812)
* [Feature] Introduce custom warp specialization attribute and enhance warp group register allocation
  - Added a new attribute `kCustomWarpSpecialization` to support custom warp specialization in the TileLang framework.
  - Updated the `Collect` method in `SetMaxNRegCollector` to handle cases where warp specialization is detected, returning an empty array accordingly.
  - Enhanced the `SetMaxNRegInjector` to skip processing when no registers are needed, improving efficiency.
  - Modified the `WarpSpecialized` pass to include the new attribute in the function body when warp specialization is enabled, ensuring proper handling in transformations.
* lint
* lint
warp_specialized_rewriter.cc to utilize the FilterByRole function, improving code clarity and maintainability.
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Style