
Commit f8ae600

[Bugfix] Do not force inline let stmt (#947)
* remove debug print
* Remove inline let expressions from the LowerAndLegalize function in phase.py
* add test
* Update sparse MLA examples to support SKV adjustment and correctness checks
  - Changed SKV parameter from 32768 to 8192 in sparse MLA backward and forward tests.
  - Added check_correctness parameter to test functions for validation of outputs.
  - Updated test cases to reflect new SKV values and correctness checks.
* reduce test shape
* Update documentation structure and refactor main function parameters in example_fusedmoe_tilelang.py
  - Added a new section for compiler internals in the documentation.
  - Refactored the main function in example_fusedmoe_tilelang.py to accept parameters for hidden dimensions, expert configurations, and batch/sequence sizes, improving flexibility and readability.
* Update buffer access checks in merge_shared_memory_allocations.cc
  - Changed the condition for buffer access from less than (<) to less than or equal to (<=) to allow access at the same scope level.
  - Adjusted the logic for determining the access level when touching buffers to ensure correct handling of scope levels.
* lint fix
* Support pipeline with LetStmt
* lint fix
* Fix LowerTileOp let handling to avoid LetInline dependency
  - inline let-bound BufferLoad nodes via resolver helpers and structured return
  - remap layouts/buffers using original data vars and only rewrite when needed
  - update pipeline planner to understand let-bound address_of buffers
  - document the new inline behaviour in docs/let_inline_fix.md
* fix for wgmma pipeline with let binding
* lint fix
* test fix
* reduce smem usage.
* let binding enhancement
* fix for dpgm
* fix simplify
* lint fix
* use tilelang.Simplify instead of tir.Simplify
* Add TL_FORCE_LET_INLINE pass config and gate eager LetInline usage
  - register the new config in builtin headers/registration
  - add helper to pipeline enabling LetInline based on pass context
  - document LetStmt inlining controls and usage
1 parent 7cd0da9 commit f8ae600


17 files changed

+804
-98
lines changed

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# LetStmt Inlining in TileLang
2+
3+
This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance.
4+
5+
## Overview
6+
7+
A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined.
8+
9+
## When Does LetStmt Get Inlined?
10+
11+
The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met:
12+
13+
### 1. The value satisfies `CanInlineLetStmt`
14+
15+
The `CanInlineLetStmt` helper returns `true` when:
16+
17+
- **The value is a constant** (`is_const_number(op->value)` returns true)
18+
- **The value is a variable** (`op->value.as<VarNode>()` returns a node)
19+
- **The value is an integer expression without side effects**:
20+
  - The value has `int` dtype
21+
  - The side effect level is `kPure` or lower (no observable side effects)
22+
23+
```cpp
bool CanInlineLetStmt(const LetStmtNode *op) {
  if (is_const_number(op->value))
    return true;
  if (op->value.as<VarNode>())
    return true;
  // Won't face the deep expression explosion problem as in Let expression.
  // attempt to inline as much as possible if the value integer type(can be
  // index).
  if (!op->value.dtype().is_int())
    return false;
  return SideEffect(op->value) <= CallEffectKind::kPure;
}
```
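The decision order of the helper can be mirrored in a few lines of plain Python. This is only a toy model under stated assumptions: `Value`, `Effect`, and `can_inline_let_stmt` are hypothetical names invented for illustration, and the effect levels are a simplified ordering rather than the real TVM enum.

```python
from dataclasses import dataclass
from enum import IntEnum


class Effect(IntEnum):
    """Toy ordering of side-effect levels (smaller means purer). Hypothetical."""
    EXPR_ANNOTATION = 0
    PURE = 1
    READ_STATE = 2
    UPDATE_STATE = 3


@dataclass(frozen=True)
class Value:
    """Hypothetical stand-in for the value bound by a LetStmt."""
    kind: str                     # "const", "var", or "expr"
    dtype: str                    # e.g. "int32", "float32"
    effect: Effect = Effect.PURE


def can_inline_let_stmt(value: Value) -> bool:
    # Constants and plain variables are always cheap to substitute.
    if value.kind in ("const", "var"):
        return True
    # Any other expression must have a signed integer dtype...
    if not value.dtype.startswith("int"):
        return False
    # ...and must not read or update state.
    return value.effect <= Effect.PURE


print(can_inline_let_stmt(Value("expr", "int32")))    # pure index arithmetic
print(can_inline_let_stmt(Value("expr", "float32")))  # non-integer expression
```

In this model a floating-point expression is never inlined, while a floating-point constant or variable still is, because the dtype check only runs after the first two early returns.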
37+
38+
### 2. The variable is NOT used in buffer definitions
39+
40+
Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields).
41+
42+
This protection exists because:
43+
- Buffer definitions are not updated during the simplification pass
44+
- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition
45+
- This would cause compilation errors or incorrect behavior
46+
47+
The mutator checks this before dropping the binding:
48+
49+
```cpp
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());

if (can_inline && !used_in_buffer_def) {
  return body; // Inline: remove LetStmt and return body directly
}
```
56+
57+
## Example: Why Buffer Definition Variables Are Protected
58+
59+
Consider this code:
60+
61+
```python
let stride = M * 16
let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1])
buffer_a[i, j] = ...
```
66+
67+
- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects)
68+
- However, `stride` is used in `buffer_a`'s `strides` field
69+
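- If we inlined it, the `LetStmt` that defines `stride` would be removed and its uses in the body replaced by `M * 16`
70+
- But the Buffer object's fields are not updated during simplification, so `buffer_a`'s `strides` would still reference `stride`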
71+
- Later code accessing `buffer_a` would fail to find the `stride` variable
72+
73+
Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined.
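The failure mode can be reproduced with a toy model in plain Python. This sketch makes several assumptions: the buffer's strides are stored as a list that may contain variable names, the scope is a plain dictionary, and `eval_strides` is a hypothetical helper that is not part of TileLang.

```python
def eval_strides(strides, env):
    """Resolve symbolic stride names against the variables still in scope."""
    return [env[s] if isinstance(s, str) else s for s in strides]


M = 4
# The buffer object keeps referring to the variable `stride` by name.
buffer_strides = ["stride", 1]

# With the LetStmt kept, `stride` is in scope and the buffer resolves.
env_with_let = {"stride": M * 16}
resolved = eval_strides(buffer_strides, env_with_let)

# After inlining, the binding is gone but the buffer was never rewritten.
env_after_inline = {}
try:
    eval_strides(buffer_strides, env_after_inline)
    dangling = False
except KeyError:
    dangling = True

print(resolved, dangling)
```

The lookup fails only in the second case, which is the situation the `used_in_buffer_def_` check is meant to prevent.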
74+
75+
## How Variables Are Collected
76+
77+
The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions:
78+
79+
```cpp
void VisitBuffer(const Buffer &buf) {
  // Collect variables that should remain defined
  VarUseDefAnalyzer usage(Array<Var>{});
  usage(buf->data);
  for (const auto &dim : buf->shape) {
    usage(dim);
  }
  for (const auto &dim : buf->strides) {
    usage(dim);
  }
  usage(buf->elem_offset);

  // Track for use in LetStmtNode mutator
  for (const auto &var : usage.undefined_) {
    used_in_buffer_def_.insert(var.get());
  }
}
```
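As a rough illustration of how the collected set feeds the inlining decision, here is a simplified Python sketch. It assumes each buffer field is either an integer or a bare variable name, whereas the real analyzer walks arbitrary expressions; `ToyBuffer`, `collect_vars_used_in_buffer_def`, and `should_inline` are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBuffer:
    """Hypothetical buffer whose fields are ints or variable names (str)."""
    data: str
    shape: list
    strides: list = field(default_factory=list)
    elem_offset: object = 0


def collect_vars_used_in_buffer_def(buffers):
    """Gather every variable name referenced by any buffer definition."""
    used = set()
    for buf in buffers:
        fields = [buf.data, *buf.shape, *buf.strides, buf.elem_offset]
        used.update(f for f in fields if isinstance(f, str))
    return used


def should_inline(var, can_inline, used_in_buffer_def):
    """Both conditions must hold: the value is inlinable and the var is unprotected."""
    return can_inline and var not in used_in_buffer_def


buf = ToyBuffer(data="A_data", shape=["M", 128], strides=["stride", 1])
used = collect_vars_used_in_buffer_def([buf])
print(sorted(used))
print(should_inline("stride", True, used), should_inline("idx", True, used))
```

Here `stride` stays bound even though its value is inlinable, while `idx` is free to be inlined because no buffer definition mentions it.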
98+
99+
## Practical Example: Temporary Variable Issue
100+
101+
Consider this TileLang code:
102+
103+
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    tmp = T.max(A[idx], 1)
    B[idx] = tmp / 2
    A[idx] = tmp * 2
```
110+
111+
In this case:
112+
- `tmp` is an integer-like temporary variable
113+
- It satisfies `CanInlineLetStmt` (pure int expression)
114+
- It's **not** used in any buffer definition
115+
- Therefore, `tmp` **will be inlined**
116+
117+
This means the IR becomes:
118+
119+
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    B[idx] = T.max(A[idx], 1) / 2
    A[idx] = T.max(A[idx], 1) * 2
```
125+
126+
If this causes issues (e.g., `A[idx]` being read twice with different values because a write lands between the two reads, as could happen if `B` aliased `A`), it indicates a potential problem with the inlining heuristic or the code pattern.
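To make the double-read hazard concrete, the following plain-Python sketch compares the let-bound and inlined forms on ordinary lists. It is an illustration only: `with_let` and `inlined` are hypothetical helpers, integer division (`//`) stands in for `/`, and the aliasing case is constructed by passing the same list as both `A` and `B`.

```python
def with_let(A, B, idx):
    """Let-bound form: A[idx] is read once and the value is reused."""
    tmp = max(A[idx], 1)
    B[idx] = tmp // 2
    A[idx] = tmp * 2


def inlined(A, B, idx):
    """Inlined form: A[idx] is read again after B[idx] has been written."""
    B[idx] = max(A[idx], 1) // 2
    A[idx] = max(A[idx], 1) * 2


# Distinct buffers: both forms agree.
a1, b1 = [10], [0]
a2, b2 = [10], [0]
with_let(a1, b1, 0)
inlined(a2, b2, 0)
same_when_distinct = (a1, b1) == (a2, b2)

# Aliased buffers: the second read observes the first write.
s1, s2 = [10], [10]
with_let(s1, s1, 0)
inlined(s2, s2, 0)

print(same_when_distinct, s1, s2)
```

With distinct buffers the two forms produce identical results, so the rewrite is only observable when something modifies `A[idx]` between the duplicated reads.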
127+
128+
## Controlling Let Inlining via Pass Config
129+
130+
TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments.
131+
132+
```python
from tilelang import transform
from tilelang.engine.phase import LowerAndLegalize

with transform.PassContext(
    config={transform.PassConfigKey.TL_FORCE_LET_INLINE: True}
):
    lowered_mod = LowerAndLegalize(input_mod, target)
```
141+
142+
If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it.
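The gating described here can be sketched without TileLang at all. In the snippet below the pass context is modelled as a plain dictionary and passes as strings; `should_force_let_inline` and `build_pass_list` are hypothetical names, and only the key string `"tl.force_let_inline"` comes from the text above.

```python
FORCE_LET_INLINE_KEY = "tl.force_let_inline"


def should_force_let_inline(config):
    """Return True only when the user explicitly enabled the flag."""
    return bool(config.get(FORCE_LET_INLINE_KEY, False))


def build_pass_list(config):
    """Assemble a toy pipeline; the eager LetInline pass is opt-in."""
    passes = []
    if should_force_let_inline(config):
        passes.append("LetInline")
    passes.append("rest_of_pipeline")  # placeholder for the remaining passes
    return passes


print(build_pass_list({}))
print(build_pass_list({FORCE_LET_INLINE_KEY: True}))
```

Defaulting to `False` keeps the eager pass out of the pipeline unless it is requested, which matches the default behaviour described above.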
143+
144+
## Summary
145+
146+
The LetStmt inlining mechanism is a **conservative optimization** that:
147+
1. Aggressively inlines simple, pure integer expressions to simplify the IR
148+
2. Protects variables used in buffer definitions to avoid breaking buffer access
149+
3. Helps reduce IR complexity and improve code generation
150+
4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required
151+
152+
Understanding when inlining happens is crucial for:
153+
- Debugging compilation issues
154+
- Understanding generated code
155+
- Writing efficient TileLang programs
156+
- Identifying potential optimization opportunities or bugs
157+
158+
## Related Files
159+
160+
- `src/transform/simplify.cc`: Main Simplify implementation
161+
- `src/transform/frontend_legalize.cc`: Standalone LetInline pass
162+
- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining
163+
- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass

docs/index.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ deeplearning_operators/matmul
3535
deeplearning_operators/deepseek_mla
3636
:::
3737

38+
:::{toctree}
39+
:maxdepth: 1
40+
:caption: COMPILER INTERNALS
41+
42+
compiler_internals/letstmt_inline
43+
:::
44+
3845
:::{toctree}
3946
:maxdepth: 1
4047
:caption: API Reference

examples/blocksparse_attention/test_example_blocksparse_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ def test_example_tilelang_block_sparse_attn():
1616

1717

1818
def test_example_tilelang_sparse_gqa_decode_varlen_indice():
19-
example_tilelang_sparse_gqa_decode_varlen_indice.main()
19+
example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048)
2020

2121

2222
def test_example_tilelang_sparse_gqa_decode_varlen_mask():
23-
example_tilelang_sparse_gqa_decode_varlen_mask.main()
23+
example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048)
2424

2525

2626
def test_example_triton_sparse_gqa_decode_varlen_indice():

examples/fusedmoe/example_fusedmoe_tilelang.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,21 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
521521
return output
522522

523523

524-
def main():
524+
def main(d_hidden=7168,
525+
d_expert=2048,
526+
n_routed_experts=8,
527+
n_shared_experts=1,
528+
n_experts_per_token=4,
529+
batch_size=1,
530+
seq_len=8192):
525531
config = {
526-
"dhidden": 7168,
527-
"dexpert": 2048,
528-
"nroutedexperts": 8,
529-
"nsharedexperts": 1,
530-
"nexpertspertoken": 4,
531-
"bs": 1,
532-
"seqlen": 8192,
532+
"dhidden": d_hidden,
533+
"dexpert": d_expert,
534+
"nroutedexperts": n_routed_experts,
535+
"nsharedexperts": n_shared_experts,
536+
"nexpertspertoken": n_experts_per_token,
537+
"bs": batch_size,
538+
"seqlen": seq_len,
533539
"seed": 81394
534540
}
535541

examples/fusedmoe/test_example_fusedmoe.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33

44

55
def test_example_fusedmoe_tilelang():
6-
example_fusedmoe_tilelang.main()
6+
example_fusedmoe_tilelang.main(
7+
d_hidden=1024,
8+
d_expert=256,
9+
n_routed_experts=8,
10+
n_shared_experts=1,
11+
n_experts_per_token=4,
12+
batch_size=1,
13+
seq_len=1024)
714

815

916
if __name__ == "__main__":

src/op/builtin.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
2525
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
2626
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
2727
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
28+
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
2829
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
2930
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
3031
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);

src/op/builtin.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ static constexpr const char *kDisableDynamicTailSplit =
7171
static constexpr const char *kDisableThreadStorageSync =
7272
"tl.disable_thread_storage_sync";
7373

74+
/*!
75+
* \brief Force inline Let bindings during simplification.
76+
*
77+
* kForceLetInline = "tl.force_let_inline"
78+
*
79+
*/
80+
static constexpr const char *kForceLetInline = "tl.force_let_inline";
81+
7482
/*!
7583
* \brief The size of the vectorized dimension in buffer, designed by user
7684
*
@@ -441,4 +449,4 @@ TVM_DLL const Op &increase_descriptor_offset();
441449
} // namespace tl
442450
} // namespace tvm
443451

444-
#endif // TVM_TL_OP_BUILTIN_H_
452+
#endif // TVM_TL_OP_BUILTIN_H_

src/transform/inject_pipeline.cc

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <tvm/tir/builtin.h>
2727
#include <tvm/tir/transform.h>
2828

29+
#include <functional>
2930
#include <unordered_set>
3031
#include <utility>
3132

@@ -845,24 +846,77 @@ class PipelineInjector : private StmtExprMutator {
845846
// Step 2: Find the body and buffer allocations of the pipeline. The body
846847
// can be direct child of the for-loop. If the for-loop has BlockRealize as
847848
// its child, the pipeline body will be the child of the block.
848-
Stmt pipeline_body{nullptr};
849+
Stmt pipeline_body_root{nullptr};
850+
bool pipeline_body_from_block = false;
849851
Array<Buffer> pipeline_allocs;
850852
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
851853
const auto &block = realize->block;
852854
for (const auto &buffer : block->alloc_buffers) {
853855
ICHECK(buffer->IsInstance<BufferNode>());
854856
buffer_data_to_buffer_.Set(buffer->data, buffer);
855857
}
856-
pipeline_body = block->body;
858+
pipeline_body_root = block->body;
857859
pipeline_allocs = block->alloc_buffers;
860+
pipeline_body_from_block = true;
858861
} else {
859-
pipeline_body = for_node->body;
862+
pipeline_body_root = for_node->body;
860863
}
861864

862-
const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
863-
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
864-
"should be SeqStmt, got "
865-
<< pipeline_body->GetTypeKey();
865+
const SeqStmtNode *pipeline_body_seq = nullptr;
866+
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
867+
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
868+
ObjectRef node = attr->node;
869+
String attr_key = attr->attr_key;
870+
PrimExpr value = attr->value;
871+
Span span = attr->span;
872+
rewrap_fns.emplace_back(
873+
[node = std::move(node), attr_key = std::move(attr_key),
874+
value = std::move(value), span](Stmt body) -> Stmt {
875+
return AttrStmt(node, attr_key, value, body, span);
876+
});
877+
};
878+
{
879+
Stmt current = pipeline_body_root;
880+
while (true) {
881+
if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
882+
pipeline_body_seq = seq_stmt;
883+
break;
884+
}
885+
if (const auto *if_then_else = current.as<IfThenElseNode>()) {
886+
ICHECK(!if_then_else->else_case.defined())
887+
<< "InjectSoftwarePipeline: Can't handle the body of the loop "
888+
"because the IfThenElse node has an else branch";
889+
PrimExpr condition = if_then_else->condition;
890+
Span span = if_then_else->span;
891+
rewrap_fns.emplace_back(
892+
[condition = std::move(condition), span](Stmt body) -> Stmt {
893+
return IfThenElse(condition, body, Stmt(), span);
894+
});
895+
current = if_then_else->then_case;
896+
continue;
897+
}
898+
if (const auto *let_stmt = current.as<LetStmtNode>()) {
899+
Var var = let_stmt->var;
900+
PrimExpr value = let_stmt->value;
901+
Span span = let_stmt->span;
902+
rewrap_fns.emplace_back([var = std::move(var),
903+
value = std::move(value),
904+
span](Stmt body) -> Stmt {
905+
return LetStmt(var, value, body, span);
906+
});
907+
current = let_stmt->body;
908+
continue;
909+
}
910+
if (const auto *attr = current.as<AttrStmtNode>()) {
911+
append_attr_wrapper(attr);
912+
current = attr->body;
913+
continue;
914+
}
915+
LOG(FATAL) << "ValueError: The body of the software pipeline should be "
916+
<< "SeqStmt, got " << current->GetTypeKey();
917+
}
918+
}
919+
ICHECK(pipeline_body_seq != nullptr);
866920

867921
// Step 3: Blockize the components of the pipeline. Each child of the
868922
// pipelined loop will be converted into a block.
@@ -934,6 +988,27 @@ class PipelineInjector : private StmtExprMutator {
934988
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
935989
GetRef<For>(op), pipeline_info)
936990
.BuildPipeline();
991+
auto apply_wrappers = [&](Stmt stmt) {
992+
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
993+
stmt = (*it)(stmt);
994+
}
995+
return stmt;
996+
};
997+
if (!rewrap_fns.empty()) {
998+
if (pipeline_body_from_block) {
999+
BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
1000+
Block pipeline_block = pipeline_realize->block;
1001+
{
1002+
BlockNode *block_node = pipeline_block.CopyOnWrite();
1003+
block_node->body = apply_wrappers(block_node->body);
1004+
}
1005+
pipeline = BlockRealize(pipeline_realize->iter_values,
1006+
pipeline_realize->predicate, pipeline_block,
1007+
pipeline_realize->span);
1008+
} else {
1009+
pipeline = apply_wrappers(pipeline);
1010+
}
1011+
}
9371012

9381013
if (const auto *realize = op->body.as<BlockRealizeNode>()) {
9391014
const auto &block = realize->block;
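
The `inject_pipeline.cc` hunks above peel `IfThenElse`, `LetStmt`, and `AttrStmt` wrappers off the loop body until a `SeqStmt` is found, then re-apply them in reverse order around the rewritten pipeline. A minimal Python sketch of that peel-and-rewrap idea follows; it uses nested tuples as a hypothetical stand-in for TIR statements, and `peel_wrappers` and `apply_wrappers` are illustrative names rather than functions from the codebase.

```python
def peel_wrappers(stmt):
    """Strip let/if/attr wrappers until a ("seq", [...]) node is reached.

    Returns the seq node and a list of closures that rebuild each wrapper,
    ordered from outermost to innermost.
    """
    rewrap = []
    current = stmt
    while True:
        tag = current[0]
        if tag == "seq":
            return current, rewrap
        if tag == "let":
            _, var, value, inner = current
            rewrap.append(lambda b, var=var, value=value: ("let", var, value, b))
        elif tag == "if":
            _, cond, inner = current
            rewrap.append(lambda b, cond=cond: ("if", cond, b))
        elif tag == "attr":
            _, key, value, inner = current
            rewrap.append(lambda b, key=key, value=value: ("attr", key, value, b))
        else:
            raise ValueError(f"pipeline body should be a seq, got {tag}")
        current = inner


def apply_wrappers(stmt, rewrap):
    """Re-apply wrappers innermost first so the original nesting is restored."""
    for fn in reversed(rewrap):
        stmt = fn(stmt)
    return stmt


body = ("let", "x", 1, ("if", "c", ("attr", "k", "v", ("seq", ["s0", "s1"]))))
seq, rewrap = peel_wrappers(body)
rewritten = apply_wrappers(("seq", ["pipelined"]), rewrap)
print(rewritten)
```

Iterating the wrapper list in reverse matters: the last wrapper peeled was the innermost one, so it has to be the first one put back.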
