[Feat] Add tilelang T.assume support and assume injection for buffer shapes #787
Walkthrough
Inserts a TL/TIR pass InjectAssumes that rewrites PrimFunc bodies to add deduplicated runtime buffer-shape checks.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Engine as LowerAndLegalize
participant TL as tilelang.transform
participant Pass as tl.InjectAssumes
participant Func as PrimFunc
Engine->>TL: FrontendLegalize(mod)
Engine->>TL: InjectAssumes()(mod) %% new insertion
TL->>Pass: apply to PrimFunc(s)
Pass->>Func: rewrite body (insert AttrStmt tilelang_assume checks)
Func-->>TL: transformed PrimFunc(s)
Engine->>TL: continue remaining passes
sequenceDiagram
autonumber
participant Pass as InjectAssumes
participant Mut as AssumeInjector
participant Func as PrimFunc
participant Analyzer as arith::Analyzer
Pass->>Mut: Substitute(Func)
Mut->>Mut: collect shapes from DeclBuffer / Block / buffer_map / alloc_buffers / match_buffers
Mut->>Mut: deduplicate shapes (structural hash/eq) and map buffers→shape
Mut->>Analyzer: simplify GT(shape, 0)
alt non-trivial condition
Mut->>Func: attach AttrStmt(tilelang_assume, cond, "buffers: ...")
else trivial/no shape
Mut-->>Func: no assume emitted
end
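The flow in the second diagram (collect shapes, deduplicate, emit an assume only when the condition is non-trivial) can be sketched in plain Python. This is a toy model under stated assumptions, not the pass itself: strings stand in for symbolic dimensions, Python's built-in `hash` and `==` stand in for structural hash and equality, each dimension is treated as one expression, and `collect_assumes` is a hypothetical name.

```python
from collections import defaultdict


def collect_assumes(buffers):
    """buffers maps a buffer name to a tuple of dims (int constants or symbolic names)."""
    buckets = defaultdict(list)  # hash -> dims already seen with that hash
    order = []                   # unique dims in first-seen order
    users = defaultdict(list)    # dim -> names of the buffers that use it
    for name, shape in buffers.items():
        for dim in shape:
            bucket = buckets[hash(dim)]  # reference to the stored list, not a copy
            if dim not in bucket:
                bucket.append(dim)
                order.append(dim)
            if name not in users[dim]:
                users[dim].append(name)

    assumes = []
    for dim in order:
        # Toy stand-in for the analyzer: a positive constant makes "dim > 0"
        # trivially true, so nothing is emitted for it.
        if isinstance(dim, int) and dim > 0:
            continue
        assumes.append((f"{dim} > 0", "buffers: " + ", ".join(users[dim])))
    return assumes


result = collect_assumes({"A": ("M", 128), "B": ("M", "N"), "C": (128, "N")})
```

Here `M` and `N` each produce a single condition even though two buffers share them, and the constant `128` produces none.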
Summary of Changes
Hello @kurisu6912, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new optimization pass aimed at accelerating the TVM prover by programmatically injecting assumptions about buffer dimensions into the intermediate representation. By making these constraints explicit, the prover can operate more efficiently. Furthermore, the changes include the addition of debugging tools that allow developers to save and inspect the IR at different points in the compilation pipeline, which will be valuable for understanding and troubleshooting transformations.
Highlights
- New InjectAssumes Pass: Implemented a C++ pass (InjectAssumes) that automatically inserts assertions into the IR, ensuring buffer shapes are positive. This is intended to provide explicit constraints for the TVM prover.
- Compilation Pipeline Integration: The InjectAssumes pass has been integrated into the LowerAndLegalize phase of the TileLang compilation pipeline, executing after frontend legalization.
- Enhanced Debugging Capabilities: Added functionality to dump intermediate IR modules to a 'debug' directory at various stages of the compilation process, facilitating easier inspection and debugging.
Code Review
This pull request introduces a new InjectAssumes TIR transform pass to speed up the TVM prover by adding assertions that buffer shapes are positive. The changes include the C++ implementation of the pass, its Python bindings, and its integration into the compilation pipeline. My review focuses on correctness and best practices. I've identified a logic bug in the new C++ pass that could lead to incorrect behavior, an inefficiency in the same file, and some leftover debugging code in the Python pipeline that should be removed. I've also suggested an improvement to a docstring for better clarity.
src/transform/inject_assumes.cc
Outdated
Stmt VisitStmt_(const BlockNode* op) final {
  auto body = VisitStmt(op->body);
  AssertCreator c;
  if(root_node) {
    for(auto item: f->buffer_map) {
      c.addBuffer(item.second);
    }
  }
The root_node flag is intended to process function parameters only for the outermost block. However, it's never set to false, causing f->buffer_map to be processed for every BlockNode in the function body, which is incorrect and inefficient. You should capture the root_node state, set it to false before recursion, and then use the captured state.
Stmt VisitStmt_(const BlockNode* op) final {
  bool is_root = root_node;
  if (is_root) root_node = false;
  auto body = VisitStmt(op->body);
  AssertCreator c;
  if(is_root) {
    for(auto item: f->buffer_map) {
      c.addBuffer(item.second);
    }
  }
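To see why the flag must be captured and cleared before recursing, here is a small Python model of the visitor. It is only an illustration: `Block`, `Visitor`, and the `clear_flag` switch are hypothetical names, and appending to a list stands in for processing `f->buffer_map`.

```python
class Block:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


class Visitor:
    def __init__(self, clear_flag):
        self.clear_flag = clear_flag  # False reproduces the reported bug
        self.root_node = True
        self.params_handled_in = []   # blocks that processed function parameters

    def visit(self, block):
        is_root = self.root_node
        if is_root and self.clear_flag:
            self.root_node = False    # children must not see themselves as root
        for child in block.children:
            self.visit(child)
        if is_root:
            self.params_handled_in.append(block.name)


tree = Block("root", [Block("inner_a", [Block("leaf")]), Block("inner_b")])

buggy = Visitor(clear_flag=False)
buggy.visit(tree)

fixed = Visitor(clear_flag=True)
fixed.visit(tree)
```

With the flag never cleared, all four blocks handle the parameters; with the capture-then-clear pattern, only `root` does.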
tilelang/engine/phase.py
Outdated
debug_path = Path('debug')
debug_path.mkdir(exist_ok=True)
This code unconditionally creates a debug directory in the current working directory. This is a side effect that can be problematic (e.g., if write permissions are not available) and is generally not desirable in library code. This, along with the various debug_path.joinpath(...).write_text(...) calls (e.g., on lines 103, 106, 145, 147, 211), appears to be debugging code that should be removed before merging.
src/transform/inject_assumes.cc
Outdated
auto bucket = buckets[h];
auto it = std::find_if(bucket.begin(), bucket.end(), [&](auto y) {
  return se(e, y, true);
});
if(it == bucket.end()) {
  exprs.push_back(e);
  buckets[h].push_back(e);
}
Accessing the buckets map with buckets[h] creates a copy of the std::vector<tvm::PrimExpr>, which is inefficient. You should use a reference (auto&) to avoid this unnecessary copy. This also allows you to simplify the code by using the reference to push back the new element.
- auto bucket = buckets[h];
- auto it = std::find_if(bucket.begin(), bucket.end(), [&](auto y) {
-   return se(e, y, true);
- });
- if(it == bucket.end()) {
-   exprs.push_back(e);
-   buckets[h].push_back(e);
- }
+ auto& bucket = buckets[h];
+ auto it = std::find_if(bucket.begin(), bucket.end(), [&](auto y) {
+   return se(e, y, true);
+ });
+ if(it == bucket.end()) {
+   exprs.push_back(e);
+   bucket.push_back(e);
+ }
"""Inject Assumes

Returns:
-------
fpass : tvm.transform.Pass
    The result pass
"""
The docstring is a bit sparse and the Returns: section is inconsistent with other functions in this file. It would be helpful to explain what is being injected and why, and to align with the existing docstring format.
- """Inject Assumes
- Returns:
- -------
- fpass : tvm.transform.Pass
-     The result pass
- """
+ """Inject assertions that buffer shapes are positive to speed up the prover.
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+     The result pass
+ """
Actionable comments posted: 3
🧹 Nitpick comments (6)
src/transform/inject_assumes.cc (3)
31-38: Avoid copying buckets when deduplicating shapes.
This copies the vector; use a reference.
-    auto bucket = buckets[h];
+    auto& bucket = buckets[h];
48-51: Make zero-constant dtype-stable.
Comparing with literal 0 may upcast to int32 and mismatch 64-bit shape dtypes. Consider GT(expr, make_const(expr.dtype(), 0)) (or Cast(expr->dtype, 0)).
96-101: Pass lambda should take const refs to avoid copies (align with other passes).
Not critical, but matches surrounding style and avoids unnecessary copies.
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+  auto pass_func = [=](PrimFunc f, const IRModule& m, const PassContext& ctx) {
tilelang/transform/__init__.py (1)
89-89: Optional: silence type checkers for _ffi_api dynamic attr.
Other wrappers use # type: ignore; add it for consistency.
-    return _ffi_api.InjectAssumes()
+    return _ffi_api.InjectAssumes()  # type: ignore
tilelang/engine/phase.py (2)
65-67: Gate debug directory creation; avoid side effects on import.
Creating debug/ at import time can be undesirable in sandboxes and breaks read-only environments.
-debug_path = Path('debug')
-debug_path.mkdir(exist_ok=True)
+import os
+DEBUG_DUMP = bool(int(os.getenv("TL_DEBUG_IR", "0")))
+debug_path = Path(os.getenv("TL_DEBUG_DIR", "debug")) if DEBUG_DUMP else None
+if DEBUG_DUMP:
+    debug_path.mkdir(parents=True, exist_ok=True)
103-106: Guard IR dumps behind a flag to reduce I/O and noise.
Unconditional writes can slow runs and pollute working dirs. Use the DEBUG_DUMP gate.
-debug_path.joinpath('LowerTileOp.0.py').write_text(mod.script(show_meta=True))
+if DEBUG_DUMP:
+    debug_path.joinpath('LowerTileOp.0.py').write_text(mod.script(show_meta=True))
@@
-debug_path.joinpath('LowerTileOp.1.py').write_text(mod.script(show_meta=True))
+if DEBUG_DUMP:
+    debug_path.joinpath('LowerTileOp.1.py').write_text(mod.script(show_meta=True))
@@
-debug_path.joinpath('InjectFenceProxy.0.py').write_text(mod.script(show_meta=True))
+if DEBUG_DUMP:
+    debug_path.joinpath('InjectFenceProxy.0.py').write_text(mod.script(show_meta=True))
@@
-debug_path.joinpath('InjectFenceProxy.1.py').write_text(mod.script(show_meta=True))
+if DEBUG_DUMP:
+    debug_path.joinpath('InjectFenceProxy.1.py').write_text(mod.script(show_meta=True))
@@
-debug_path.joinpath('PersistThreadblock.1.py').write_text(mod.script(show_meta=True))
+if DEBUG_DUMP:
+    debug_path.joinpath('PersistThreadblock.1.py').write_text(mod.script(show_meta=True))
Also applies to: 145-147, 211-211
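If IR dumps were kept at all, the repeated guards could collapse into one helper. The sketch below is hypothetical: `dump_ir` is an invented name, and `TL_DEBUG_IR` / `TL_DEBUG_DIR` are the environment variables proposed in the suggestion above, not existing TileLang settings. The directory is created lazily on the first enabled write, so importing the module has no side effects.

```python
import os
import tempfile
from pathlib import Path


def dump_ir(name, text, env=os.environ):
    """Write text to <TL_DEBUG_DIR>/<name> only when TL_DEBUG_IR is a non-zero integer."""
    if not int(env.get("TL_DEBUG_IR", "0")):
        return None  # dumping disabled: no directory created, no file written
    debug_dir = Path(env.get("TL_DEBUG_DIR", "debug"))
    debug_dir.mkdir(parents=True, exist_ok=True)
    path = debug_dir / name
    path.write_text(text)
    return path


# Usage with an explicit env mapping so the example does not touch the real environment.
with tempfile.TemporaryDirectory() as tmp:
    off = dump_ir("LowerTileOp.0.py", "# IR text", env={"TL_DEBUG_DIR": tmp})
    on = dump_ir("LowerTileOp.0.py", "# IR text",
                 env={"TL_DEBUG_IR": "1", "TL_DEBUG_DIR": tmp})
    written = on.read_text()
```

In the pipeline, a call site would then read something like `dump_ir('LowerTileOp.0.py', mod.script(show_meta=True))`.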
📒 Files selected for processing (3)
- src/transform/inject_assumes.cc (1 hunks)
- tilelang/engine/phase.py (5 hunks)
- tilelang/transform/__init__.py (1 hunks)
🔇 Additional comments (3)
src/transform/inject_assumes.cc (1)
53-55: Use tvm::String for AssertStmt message (StringImm is wrong type).
Second parameter to AssertStmt is tvm::String, not a PrimExpr.
-    auto msg = StringImm("Invalid Buffer Shape: buffer shape should be greater than 0");
-    return AssertStmt(simplified, msg, body);
+    return AssertStmt(
+        simplified,
+        tvm::String("Invalid Buffer Shape: buffer shape should be greater than 0"),
+        body);
Likely an incorrect or invalid review comment.
tilelang/transform/__init__.py (1)
81-90: Wrapper looks good; expose via FFI is consistent.
tilelang/engine/phase.py (1)
95-97: Good placement of InjectAssumes in the pipeline.
Running after FrontendLegalize and before Simplify maximizes simplification opportunities.
#include "tvm/arith/analyzer.h"
#include "tvm/ir/expr.h"
#include "tvm/ir/transform.h"
#include "tvm/node/structural_hash.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h"
#include "tvm/tir/transform.h"
Add missing headers to prevent build failures.
std::unordered_map, std::vector, and std::find_if require STL headers; StructuralEqual needs its header too.
Apply:
#include "tvm/arith/analyzer.h"
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+#include "tvm/node/structural_equal.h"
#include "tvm/ir/expr.h"
#include "tvm/ir/transform.h"
#include "tvm/node/structural_hash.h"
src/transform/inject_assumes.cc
Outdated
Stmt VisitStmt_(const BlockNode* op) final {
  auto body = VisitStmt(op->body);
  AssertCreator c;
  if(root_node) {
    for(auto item: f->buffer_map) {
      c.addBuffer(item.second);
    }
  }
  for(auto item: op->alloc_buffers) {
Root-block detection is broken; toggle before recursion and only once.
root_node is never flipped and you recurse before toggling, so every nested Block acts like root and duplicates buffer_map asserts.
-  Stmt VisitStmt_(const BlockNode* op) final {
-    auto body = VisitStmt(op->body);
-    AssertCreator c;
-    if(root_node) {
+  Stmt VisitStmt_(const BlockNode* op) final {
+    bool was_root = root_node;
+    root_node = false; // ensure children are not treated as root
+    AssertCreator c;
+    if (was_root) {
       for(auto item: f->buffer_map) {
         c.addBuffer(item.second);
       }
     }
+    auto body = VisitStmt(op->body);
…shapes (tile-ai#787)
* Add InjectAssumes pass to speedup tvm prover
* Fix lint errors
* remove debug statements
* [Feat] add assume attr and assume support in tilelang
* Add convertion from tir.assume to tilelang assume
* [Fix] Add missing With constraint in IRMutator
* Fix typo in ir mutator