[Feature] Implement LoopUnswitching Pass#1747
Conversation
- Added a new pass to hoist loop-invariant if statements out of loops, improving optimization opportunities. - Introduced classes for collecting and checking conditions, as well as for rewriting statements. - Integrated the new pass into the optimization pipeline and provided a corresponding API for usage.
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
Caution Review failedThe pull request is closed. 📝 WalkthroughWalkthroughAdds a new LoopUnswitching C++ transform with FFI binding, exposes it in Python, inserts it into the OptimizeForTarget pipeline after StorageRewrite, adds a pass-config toggle and builtin attr key, and includes unit and functional tests validating loop-invariant conditional hoisting. Changes
sequenceDiagram
participant IR as IR Module
participant LU as LoopUnswitching Pass
participant WVC as WrittenVarCollector
participant HIF as HoistableIfFinder
participant CLN as LoopCloner
participant OUT as Transformed IR
IR->>LU: ApplyLoopUnswitching(stmt)
LU->>WVC: collect written buffers / write effects
WVC-->>LU: written buffer set
LU->>HIF: locate hoistable IfThenElse and check invariance
HIF-->>LU: target If node + captured lets
LU->>CLN: clone loop -> then_loop (replace If with then branch)
CLN-->>LU: then_loop
LU->>CLN: clone loop -> else_loop (replace If with else branch)
CLN-->>LU: else_loop
LU->>OUT: emit IfThenElse selecting then_loop / else_loop
🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
…ache disabling in benchmark script - Introduced CallNodeChecker class to identify CallNode expressions in loop conditions, enhancing loop-invariant checks. - Updated IsLoopInvariant function to reject conditions containing CallNodes, preventing potential side effects. - Added tilelang.disable_cache() in benchmark_mha_sink_fwd.py to optimize performance during benchmarking.
- Added support for Let-bound variables in the WrittenBufferReadChecker to improve buffer read checks. - Introduced UsesLoopVarThroughLetBindings function to check if conditions depend on loop variables through Let bindings. - Updated IsLoopInvariant function to account for Let bindings when determining loop invariance. - Enhanced HoistableIfFinder to track Let bindings for variables bound to BufferLoad expressions. - Added debug print statements in the OptimizeForTarget function to visualize the module state before and after loop unswitching.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/transform/loop_unswitching.cc`:
- Around line 306-313: The else loop reuses op->thread_binding which still
references the original loop var; create a mirrored IterVar (or copy) whose var
is else_loop_var when op->thread_binding is non-null and use that mirrored
IterVar in the else_loop construction so the IterVar in thread_binding matches
else_loop_var; locate the creation of Var else_loop_var and the For
else_loop(...) and replace op->thread_binding with a remapped copy of
op->thread_binding that references else_loop_var (preserving other
fields/annotations) before constructing else_loop.
- Around line 148-176: The current UsesLoopVarThroughLetBindings only checks
bound expressions that directly contain the loop var and misses chaining through
Var-to-Var lets; update HoistableIfFinder::VisitStmt_(const LetStmtNode*) to
record every LetStmt binding (map VarNode* -> PrimExpr) not just BufferLoad
cases, and rewrite UsesLoopVarThroughLetBindings to recursively resolve Var
bindings following the same pattern as WrittenBufferReadChecker so that when a
VarNode is encountered it looks up its binding and recursively checks that
binding (and avoids infinite recursion via visited set).
| bool UsesLoopVarThroughLetBindings( | ||
| const PrimExpr &cond, const Var &loop_var, | ||
| const std::unordered_map<const VarNode *, PrimExpr> *let_bindings) { | ||
| // Check if condition directly uses loop variable | ||
| if (UsesVar(cond, [&](const VarNode *v) { return v == loop_var.get(); })) { | ||
| return true; | ||
| } | ||
|
|
||
| // Check if any Let-bound variable used in condition has a binding that uses | ||
| // the loop variable | ||
| if (let_bindings) { | ||
| bool uses_loop_var = false; | ||
| PostOrderVisit(cond, [&](const ObjectRef &obj) { | ||
| if (uses_loop_var) | ||
| return; | ||
| if (const auto *var_node = obj.as<VarNode>()) { | ||
| auto it = let_bindings->find(var_node); | ||
| if (it != let_bindings->end()) { | ||
| // Check if the bound expression uses the loop variable | ||
| if (UsesVar(it->second, | ||
| [&](const VarNode *v) { return v == loop_var.get(); })) { | ||
| uses_loop_var = true; | ||
| } | ||
| } | ||
| } | ||
| }); | ||
| if (uses_loop_var) { | ||
| return true; | ||
| } |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
cat -n src/transform/loop_unswitching.cc | sed -n '140,180p'Repository: tile-ai/tilelang
Length of output: 1530
🏁 Script executed:
rg -B5 -A5 "HoistableIfFinder" src/transform/loop_unswitching.cc | head -100Repository: tile-ai/tilelang
Length of output: 1055
🏁 Script executed:
rg -A20 "void VisitStmt_\(const LetStmtNode" src/transform/loop_unswitching.ccRepository: tile-ai/tilelang
Length of output: 719
🏁 Script executed:
rg -B5 -A5 "IsLoopInvariant" src/transform/loop_unswitching.cc | head -150Repository: tile-ai/tilelang
Length of output: 797
🏁 Script executed:
find . -path "*/test*" -name "*loop_unswitching*" -o -path "*/test*" -name "*hoisting*" | head -20Repository: tile-ai/tilelang
Length of output: 130
🏁 Script executed:
cat -n ./testing/python/transform/test_tilelang_transform_loop_unswitching.pyRepository: tile-ai/tilelang
Length of output: 7271
🏁 Script executed:
rg -B10 "IsLoopInvariant" src/transform/loop_unswitching.cc | grep -A15 "bool IsLoopInvariant"Repository: tile-ai/tilelang
Length of output: 452
🏁 Script executed:
rg -A30 "^bool IsLoopInvariant" src/transform/loop_unswitching.ccRepository: tile-ai/tilelang
Length of output: 1115
🏁 Script executed:
rg -B5 -A20 "class WrittenBufferReadChecker" src/transform/loop_unswitching.ccRepository: tile-ai/tilelang
Length of output: 971
🏁 Script executed:
rg -A15 "void VisitExpr_\(const VarNode \*op\)" src/transform/loop_unswitching.ccRepository: tile-ai/tilelang
Length of output: 557
Store all LetStmt bindings and add recursive loop-variance checking.
Line 164-169 only checks if bound expressions directly use the loop variable. With chained Let bindings (e.g., v ← BufferLoad, w ← v), loop-variant dependencies or written-buffer reads are not detected. For example:
let v = A[0] // BufferLoad → stored in let_bindings
let w = v // Var binding → NOT stored
if w > 0 // uses w, but w's dependency on A[0] goes undetected
To fix this properly:
- Modify
HoistableIfFinder::VisitStmt_(const LetStmtNode*)to store all Let bindings (not just BufferLoad values) - Make
UsesLoopVarThroughLetBindingsrecursive so it follows binding chains (same pattern asWrittenBufferReadChecker)
Suggested recursive binding check
- if (UsesVar(it->second,
- [&](const VarNode *v) { return v == loop_var.get(); })) {
- uses_loop_var = true;
- }
+ if (UsesLoopVarThroughLetBindings(it->second, loop_var,
+ let_bindings)) {
+ uses_loop_var = true;
+ }🤖 Prompt for AI Agents
In `@src/transform/loop_unswitching.cc` around lines 148 - 176, The current
UsesLoopVarThroughLetBindings only checks bound expressions that directly
contain the loop var and misses chaining through Var-to-Var lets; update
HoistableIfFinder::VisitStmt_(const LetStmtNode*) to record every LetStmt
binding (map VarNode* -> PrimExpr) not just BufferLoad cases, and rewrite
UsesLoopVarThroughLetBindings to recursively resolve Var bindings following the
same pattern as WrittenBufferReadChecker so that when a VarNode is encountered
it looks up its binding and recursively checks that binding (and avoids infinite
recursion via visited set).
| // Create new loop_var for else_loop to maintain SSA form | ||
| Var else_loop_var(op->loop_var->name_hint, op->loop_var->dtype); | ||
| else_body = Substitute(else_body, {{op->loop_var, else_loop_var}}); | ||
|
|
||
| For then_loop(op->loop_var, op->min, op->extent, op->kind, then_body, | ||
| op->thread_binding, op->annotations); | ||
| For else_loop(else_loop_var, op->min, op->extent, op->kind, else_body, | ||
| op->thread_binding, op->annotations); |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file and locate the exact code in question
cat -n src/transform/loop_unswitching.cc | sed -n '300,320p'Repository: tile-ai/tilelang
Length of output: 983
🏁 Script executed:
# Let's also check the structure of IterVar and related classes
find . -name "*.h" -o -name "*.cc" | xargs rg -l "class IterVar" | head -5Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for similar thread_binding patterns in the codebase to understand the expected usage
rg "thread_binding\.defined\(\)" --type cc -A 3 -B 1 | head -50Repository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
# Find IterVar definition
find . -type f -name "*.h" | xargs rg "class IterVar" -B 2 -A 10 | head -60Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for how thread_binding is properly handled when creating new loops
rg "IterVar.*thread" -B 3 -A 3 --type-list | head -80Repository: tile-ai/tilelang
Length of output: 449
🏁 Script executed:
# Search for For loop constructor to understand parameters
rg "class For|struct For" -B 2 -A 20 --type-list | head -100Repository: tile-ai/tilelang
Length of output: 449
🏁 Script executed:
# Look for patterns where thread_binding is reassigned or recreated
rg "thread_binding\s*=" -B 2 -A 2 | head -60Repository: tile-ai/tilelang
Length of output: 4376
🏁 Script executed:
# Search for .h files containing IterVar definition
find . -type f -name "*.h" | head -20Repository: tile-ai/tilelang
Length of output: 471
🏁 Script executed:
# Look for IterVar in header files
find . -type f -name "*.h" -exec grep -l "IterVar" {} \;Repository: tile-ai/tilelang
Length of output: 391
🏁 Script executed:
# Get the For loop and IterVar definitions
find . -type f -name "*.h" -exec grep -l "class For\|struct For" {} \;Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Let's examine parallel.h which likely has IterVar definition
cat -n src/op/parallel.h | head -150Repository: tile-ai/tilelang
Length of output: 6995
🏁 Script executed:
# Search for For definition more broadly
grep -r "class For\|struct For" --include="*.h" | head -20Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for IterVar constructor patterns
rg "IterVar\(" --type h -B 1 -A 2 | head -80Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for how For loops are constructed with thread_binding
rg "For\(" --type h --type cc -B 2 -A 1 | grep -A 3 "thread_binding" | head -60Repository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
# Look for IterVar construction patterns to understand the constructor
rg "IterVar\(" -B 1 -A 1 | head -100Repository: tile-ai/tilelang
Length of output: 5200
🏁 Script executed:
# Check if there are any existing thread_binding remapping patterns in the codebase
rg "thread_binding" -B 3 -A 3 --type cc | grep -A 5 -B 5 "IterVar\|Var(" | head -100Repository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
# Let's search for IterVar definition in the codebase more broadly to understand its constructor
find . -name "*.h" -o -name "*.cc" | xargs grep -l "IterVar" | head -10Repository: tile-ai/tilelang
Length of output: 289
🏁 Script executed:
# Search for thread_binding field access to understand its type
rg "thread_binding->" -B 2 -A 2 | head -80Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for For loop constructor definition
rg "For\(" -A 3 | grep -B 1 -A 3 "thread_binding" | head -60Repository: tile-ai/tilelang
Length of output: 3235
🏁 Script executed:
# Look at the src/ir.cc file to understand For constructor and thread_binding type
cat -n src/ir.cc | head -100Repository: tile-ai/tilelang
Length of output: 4520
🏁 Script executed:
# Search for For struct/class definition in TVM to understand thread_binding
rg "struct For\|class For" --include="*.h" -B 5 -A 20 | head -150Repository: tile-ai/tilelang
Length of output: 501
🏁 Script executed:
# Look at the actual loop_unswitching.cc file more broadly to see context
cat -n src/transform/loop_unswitching.cc | sed -n '290,330p'Repository: tile-ai/tilelang
Length of output: 1691
🏁 Script executed:
# Search for how thread_binding is used/accessed in the codebase
rg "thread_binding\." -B 2 -A 2 | head -100Repository: tile-ai/tilelang
Length of output: 1711
🏁 Script executed:
# Look for Optional<IterVar> to understand thread_binding type
rg "Optional.*IterVar\|thread_binding" -B 1 -A 1 | head -80Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Check if there are any existing examples of remapping or creating new thread_binding
rg "IterVar.*thread" -B 3 -A 3 | head -100Repository: tile-ai/tilelang
Length of output: 7078
🏁 Script executed:
# Check how thread_binding and loop_var are used together in code generation
rg "thread_binding\|loop_var" -B 2 -A 2 | grep -B 3 -A 3 "kThreadBinding" | head -100Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for ForKind::kThreadBinding usage to understand the relationship
rg "kThreadBinding" -B 3 -A 3 | head -100Repository: tile-ai/tilelang
Length of output: 1850
🏁 Script executed:
# Search for validation or assertions that check thread_binding consistency
rg "thread_binding.*loop_var\|loop_var.*thread_binding" -B 2 -A 2 | head -80Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Let's look at IterVar's structure more carefully to understand if it has a var field
rg "thread_binding->var\|thread_binding\(\)->var" -B 2 -A 2 | head -50Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Check how thread_binding IterVar is constructed and what fields it has
rg "IterVar.*thread" -B 1 -A 1 | head -100Repository: tile-ai/tilelang
Length of output: 6197
🏁 Script executed:
# Search for any validation or checks that compare thread_binding's var with loop_var
rg "\.var\b" -B 1 -A 1 | grep -B 2 -A 2 "thread_binding" | head -80Repository: tile-ai/tilelang
Length of output: 427
🏁 Script executed:
# Look for IterVarNode or IterVar field definitions
find . -name "*.h" -exec grep -l "IterVarNode\|IterVar {" {} \;Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Let's look for IterVar field access patterns more broadly
rg "->dom\|->iter_type\|->thread_tag\|->var" -B 1 -A 1 | grep -B 2 -A 2 "IterVar\|thread_binding" | head -120Repository: tile-ai/tilelang
Length of output: 464
🏁 Script executed:
# Check if there's any existing code that creates a new IterVar from an old one
rg "IterVar.*->dom\|IterVar.*thread_binding" -B 2 -A 2 | head -80Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Let's trace through how thread_binding is actually used in the For loop
rg "->thread_binding" -B 3 -A 3 | head -120Repository: tile-ai/tilelang
Length of output: 464
🏁 Script executed:
# Let's search for thread_binding access patterns
rg 'thread_binding\(' -B 2 -A 2 | head -100Repository: tile-ai/tilelang
Length of output: 6309
🏁 Script executed:
# Look for how For loop constructor is called with thread_binding
rg 'For\(' -A 5 | grep -B 3 -A 2 'thread_binding' | head -100Repository: tile-ai/tilelang
Length of output: 4383
🏁 Script executed:
# Check TVM documentation or examples by looking at how IterVar is created with different vars
rg 'IterVar\(' -A 1 | head -120Repository: tile-ai/tilelang
Length of output: 3850
Remap thread_binding for the else loop when thread_binding is defined.
When else_loop_var is created to maintain SSA form, the reuse of op->thread_binding creates IR inconsistency: the IterVar in thread_binding still references the original op->loop_var, but the else_loop uses else_loop_var. For thread-bound loops, code generation and validation expect loop_var and thread_binding's var to match. Create a mirrored IterVar for thread_binding when it's defined.
Suggested fix
// Create new loop_var for else_loop to maintain SSA form
Var else_loop_var(op->loop_var->name_hint, op->loop_var->dtype);
else_body = Substitute(else_body, {{op->loop_var, else_loop_var}});
+ IterVar else_thread_binding;
+ if (op->thread_binding.defined()) {
+ else_thread_binding = IterVar(op->thread_binding->dom, else_loop_var,
+ op->thread_binding->iter_type,
+ op->thread_binding->thread_tag);
+ }
+
For then_loop(op->loop_var, op->min, op->extent, op->kind, then_body,
op->thread_binding, op->annotations);
For else_loop(else_loop_var, op->min, op->extent, op->kind, else_body,
- op->thread_binding, op->annotations);
+ else_thread_binding, op->annotations);🤖 Prompt for AI Agents
In `@src/transform/loop_unswitching.cc` around lines 306 - 313, The else loop
reuses op->thread_binding which still references the original loop var; create a
mirrored IterVar (or copy) whose var is else_loop_var when op->thread_binding is
non-null and use that mirrored IterVar in the else_loop construction so the
IterVar in thread_binding matches else_loop_var; locate the creation of Var
else_loop_var and the For else_loop(...) and replace op->thread_binding with a
remapped copy of op->thread_binding that references else_loop_var (preserving
other fields/annotations) before constructing else_loop.
- Added CallCheckerExcludingIf class to ensure function calls outside of hoisted if statements are identified, preventing potential synchronization issues during loop unswitching. - Updated loop unswitching logic to incorporate the new call checker, enhancing safety and correctness. - Integrated debug print statements in OptimizeForTarget to visualize module state before and after loop unswitching. - Disabled tilelang cache in the benchmark script for improved performance.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tilelang/engine/phase.py`:
- Around line 239-243: Remove the three debug print statements around the
LoopUnswitching pass in phase.py (the lines printing "Before StorageRewrite",
the module, and "After LoopUnswitching") because they pollute stdout and the
"Before StorageRewrite" label is misleading; either delete them or replace them
with conditional logging gated by the existing should_enable_ast_print() (or a
similar pass-context flag) and, if you keep a message, correct the label to
reflect the actual ordering (e.g., "After StorageRewrite" or "Before
LoopUnswitching") around the call to tilelang.transform.LoopUnswitching()(mod).
…nts in OptimizeForTarget function
|
@regression-perf |
- Introduced a new configuration option `tl.disable_loop_unswitching` to allow users to disable the loop unswitching optimization. - Updated the Loop Unswitching pass to check this configuration and return the original function if the option is enabled. - Added relevant documentation in the PassConfigKey enumeration for clarity.
Summary by CodeRabbit
New Features
Tests
Chores
✏️ Tip: You can customize this high-level summary in your review settings.