|
| 1 | +# LetStmt Inlining in TileLang |
| 2 | + |
| 3 | +This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance. |
| 4 | + |
| 5 | +## Overview |
| 6 | + |
| 7 | +A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined. |
| 8 | + |
| 9 | +## When Does LetStmt Get Inlined? |
| 10 | + |
| 11 | +The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met: |
| 12 | + |
| 13 | +### 1. The value satisfies `CanInlineLetStmt` |
| 14 | + |
| 15 | +The `CanInlineLetStmt` helper returns `true` when: |
| 16 | + |
| 17 | +- **The value is a constant** (`is_const_number(op->value)` returns true) |
| 18 | +- **The value is a variable** (`op->value.as<VarNode>()` returns a node) |
| 19 | +- **The value is an integer expression without side effects**: |
| 20 | + - The value has `int` dtype |
| 21 | + - The side effect level is `kPure` or lower (no observable side effects) |
| 22 | + |
| 23 | +```cpp |
| 24 | +bool CanInlineLetStmt(const LetStmtNode *op) { |
| 25 | + if (is_const_number(op->value)) |
| 26 | + return true; |
| 27 | + if (op->value.as<VarNode>()) |
| 28 | + return true; |
| 29 | + // Won't face the deep expression explosion problem as in Let expression. |
| 30 | + // attempt to inline as much as possible if the value integer type(can be |
| 31 | + // index). |
| 32 | + if (!op->value.dtype().is_int()) |
| 33 | + return false; |
| 34 | + return SideEffect(op->value) <= CallEffectKind::kPure; |
| 35 | +} |
| 36 | +``` |
| 37 | +
|
| 38 | +### 2. The variable is NOT used in buffer definitions |
| 39 | +
|
| 40 | +Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields). |
| 41 | +
|
| 42 | +This protection exists because: |
| 43 | +- Buffer definitions are not updated during the simplification pass |
| 44 | +- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition |
| 45 | +- This would cause compilation errors or incorrect behavior |
| 46 | +
|
| 47 | +The mutator checks this before dropping the binding: |
| 48 | +
|
| 49 | +```cpp |
| 50 | +bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get()); |
| 51 | +
|
| 52 | +if (can_inline && !used_in_buffer_def) { |
| 53 | + return body; // Inline: remove LetStmt and return body directly |
| 54 | +} |
| 55 | +``` |
| 56 | + |
| 57 | +## Example: Why Buffer Definition Variables Are Protected |
| 58 | + |
| 59 | +Consider this code: |
| 60 | + |
| 61 | +```python |
| 62 | +let stride = M * 16 |
| 63 | +let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1]) |
| 64 | +buffer_a[i, j] = ... |
| 65 | +``` |
| 66 | + |
| 67 | +- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects) |
| 68 | +- However, `stride` is used in `buffer_a`'s `strides` field |
| 69 | +- If we inline it, the buffer definition becomes `strides=[M*16, 1]` |
| 70 | +- But the Buffer object's fields are not updated during simplification |
| 71 | +- Later code accessing `buffer_a` would fail to find the `stride` variable |
| 72 | + |
| 73 | +Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined. |
| 74 | + |
| 75 | +## How Variables Are Collected |
| 76 | + |
| 77 | +The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions: |
| 78 | + |
| 79 | +```cpp |
| 80 | +void VisitBuffer(const Buffer &buf) { |
| 81 | + // Collect variables that should remain defined |
| 82 | + VarUseDefAnalyzer usage(Array<Var>{}); |
| 83 | + usage(buf->data); |
| 84 | + for (const auto &dim : buf->shape) { |
| 85 | + usage(dim); |
| 86 | + } |
| 87 | + for (const auto &dim : buf->strides) { |
| 88 | + usage(dim); |
| 89 | + } |
| 90 | + usage(buf->elem_offset); |
| 91 | + |
| 92 | + // Track for use in LetStmtNode mutator |
| 93 | + for (const auto &var : usage.undefined_) { |
| 94 | + used_in_buffer_def_.insert(var.get()); |
| 95 | + } |
| 96 | +} |
| 97 | +``` |
| 98 | +
|
| 99 | +## Practical Example: Temporary Variable Issue |
| 100 | +
|
| 101 | +Consider this TileLang code: |
| 102 | +
|
| 103 | +```python |
| 104 | +for i in T.Parallel(block_N): |
| 105 | + idx = bx * block_N + i |
| 106 | + tmp = T.max(A[idx], 1) |
| 107 | + B[idx] = tmp / 2 |
| 108 | + A[idx] = tmp * 2 |
| 109 | +``` |
| 110 | + |
| 111 | +In this case: |
| 112 | +- `tmp` is an integer-like temporary variable |
| 113 | +- It satisfies `CanInlineLetStmt` (pure int expression) |
| 114 | +- It's **not** used in any buffer definition |
| 115 | +- Therefore, `tmp` **will be inlined** |
| 116 | + |
| 117 | +This means the IR becomes: |
| 118 | + |
| 119 | +```python |
| 120 | +for i in T.Parallel(block_N): |
| 121 | + idx = bx * block_N + i |
| 122 | + B[idx] = T.max(A[idx], 1) / 2 |
| 123 | + A[idx] = T.max(A[idx], 1) * 2 |
| 124 | +``` |
| 125 | + |
| 126 | +If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern. |
| 127 | + |
| 128 | +## Controlling Let Inlining via Pass Config |
| 129 | + |
| 130 | +TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments. |
| 131 | + |
| 132 | +```python |
| 133 | +from tilelang import transform |
| 134 | +from tilelang.engine.phase import LowerAndLegalize |
| 135 | + |
| 136 | +with transform.PassContext( |
| 137 | + config={transform.PassConfigKey.TL_FORCE_LET_INLINE: True} |
| 138 | +): |
| 139 | + lowered_mod = LowerAndLegalize(input_mod, target) |
| 140 | +``` |
| 141 | + |
| 142 | +If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it. |
| 143 | + |
| 144 | +## Summary |
| 145 | + |
| 146 | +The LetStmt inlining mechanism is a **conservative optimization** that: |
| 147 | +1. Aggressively inlines simple, pure integer expressions to simplify the IR |
| 148 | +2. Protects variables used in buffer definitions to avoid breaking buffer access |
| 149 | +3. Helps reduce IR complexity and improve code generation |
| 150 | +4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required |
| 151 | + |
| 152 | +Understanding when inlining happens is crucial for: |
| 153 | +- Debugging compilation issues |
| 154 | +- Understanding generated code |
| 155 | +- Writing efficient TileLang programs |
| 156 | +- Identifying potential optimization opportunities or bugs |
| 157 | + |
| 158 | +## Related Files |
| 159 | + |
| 160 | +- `src/transform/simplify.cc`: Main Simplify implementation |
| 161 | +- `src/transform/frontend_legalize.cc`: Standalone LetInline pass |
| 162 | +- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining |
| 163 | +- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass |
0 commit comments