[Feature] Add hoist_broadcast_values pass#1606
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughAdds a new TIR pass HoistBroadcastValues that hoists immediate constants from Broadcast nodes into temporary Vars and inserts LetStmt definitions at BufferStore boundaries; the pass is exported and invoked in device-level lowering paths after simplification. (50 words) Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Lower as Lowering Pipeline
participant Pass as HoistBroadcastValues Pass
participant Mut as HoistBroadcastValuesMutator
participant Func as PrimFunc Body
participant BCast as Broadcast Node
participant Store as BufferStore Node
participant Result as Transformed Body
Lower->>Pass: instantiate pass (post-simplify)
Pass->>Mut: create mutator (pending_defs = [])
Pass->>Func: apply mutator to function body
Func->>BCast: visit Broadcast
BCast->>Mut: visit_broadcast_ — create Var, enqueue (Var, value)
Mut->>Func: replace immediate constant with Var
Func->>Store: visit BufferStore
Store->>Mut: visit_buffer_store_ — wrap stmt with LetStmt(s) for pending_defs
Mut->>Mut: clear pending_defs
Mut->>Result: return transformed body
Pass->>Lower: return updated prim_func
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom Pre-merge Checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Fix all issues with AI Agents 🤖
In @tilelang/transform/hoist_broadcast_values.py:
- Around line 34-54: visit_buffer_store_ currently only visits op.value and
skips op.indices, so Broadcasts in buffer indices won’t be hoisted; change the
visit to traverse indices (call self.visit_expr on each index or the helper used
for expressions) before constructing new_expr/new_stmt so pending_defs from
index visits get wrapped into LetStmt(s) as you already do for value, keeping
the reversed wrapping logic and clearing pending_defs afterward; also add
analogous visitor methods for LetStmt, Evaluate, IfThenElse, and AssertStmt
(implementations should clear pending_defs, visit their contained
expressions/statements so visit_broadcast_ can enqueue defs, then wrap with
LetStmt(s) like in visit_buffer_store_), and update the module docstring to
accurately state that visit_broadcast_ hoists all Broadcast values (not just
IntImm/FloatImm).
🧹 Nitpick comments (2)
tilelang/transform/__init__.py (1)
10-10: Remove unnecessarynoqadirective.Ruff reports the
noqa: F401directive is for a non-enabled rule. Since F401 isn't enforced in this project's configuration, the directive is redundant. Note that other imports in this file (lines 5, 6, 9) also have this pattern, so this could be addressed consistently across the file in a follow-up.tilelang/transform/hoist_broadcast_values.py (1)
76-79: Unusedmodandctxparameters are idiomatic for TVM passes.The Ruff warnings about unused
modandctxcan be suppressed by prefixing with underscores, which is the Python convention for intentionally unused parameters.🔎 Proposed fix
- def pass_fn(func: PrimFunc, mod, ctx): + def pass_fn(func: PrimFunc, _mod, _ctx): mutator = HoistBroadcastValuesMutator() new_body = mutator.visit_stmt(func.body) return func.with_body(new_body)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
tilelang/engine/lower.pytilelang/transform/__init__.pytilelang/transform/hoist_broadcast_values.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/transform/__init__.py (1)
tilelang/transform/hoist_broadcast_values.py (1)
HoistBroadcastValues(57-81)
tilelang/engine/lower.py (1)
tilelang/transform/hoist_broadcast_values.py (1)
HoistBroadcastValues(57-81)
tilelang/transform/hoist_broadcast_values.py (2)
tilelang/language/ast/ir.py (1)
LetStmt(874-902)tilelang/language/v2/builder.py (1)
PrimFunc(690-699)
🪛 Ruff (0.14.10)
tilelang/transform/__init__.py
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/transform/hoist_broadcast_values.py
76-76: Unused function argument: mod
(ARG001)
76-76: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
tilelang/engine/lower.py (2)
191-205: LGTM!The
HoistBroadcastValuespass is correctly integrated afterSimplify()and before target-specific codegen. This ordering ensures the pass operates on simplified IR.
208-230: LGTM!The integration in
device_codegen_without_compilemirrors the placement indevice_codegen, maintaining consistency across both code paths.
There was a problem hiding this comment.
Actionable comments posted: 1
Fix all issues with AI Agents 🤖
In @tilelang/transform/hoist_broadcast_values.py:
- Around line 20-34: The hoisted broadcast variables currently all use the
literal name "broadcast_var" in visit_broadcast_, causing collisions when
multiple broadcasts are hoisted; change the naming to generate a unique
identifier per hoist (e.g., use a running counter like self._broadcast_counter
or incorporate a unique suffix) when creating new_var, increment the counter
each time you create a new Var, and keep the rest of the logic (appending to
self.pending_defs and returning Broadcast(new_var, op.lanes)) unchanged so each
hoisted value gets a unique variable name.
♻️ Duplicate comments (1)
tilelang/transform/hoist_broadcast_values.py (1)
62-79: Update the docstring example to match the actual implementation.The example transformation shows variables named
bv_3_14andbv_3_14_1, but the implementation creates variables namedbroadcast_var(orbroadcast_var_0,broadcast_var_1, etc., after fixing the naming collision). Update the example to reflect the actual variable naming scheme.🔎 Proposed docstring correction
Example Transformation: ----------------------- Before: A[i] = B[i] + T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4) After: - bv_3_14 = 3.14 - bv_3_14_1 = 3.14 - A[i] = B[i] + T.Broadcast(bv_3_14, 4) + T.Broadcast(bv_3_14_1, 4) + broadcast_var_0 = 3.14 + broadcast_var_1 = 3.14 + A[i] = B[i] + T.Broadcast(broadcast_var_0, 4) + T.Broadcast(broadcast_var_1, 4)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/transform/hoist_broadcast_values.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
tilelang/transform/hoist_broadcast_values.py
🧬 Code graph analysis (1)
tilelang/transform/hoist_broadcast_values.py (2)
tilelang/language/ast/ir.py (1)
LetStmt(874-902)tilelang/language/v2/builder.py (1)
PrimFunc(690-699)
🪛 Ruff (0.14.10)
tilelang/transform/hoist_broadcast_values.py
81-81: Unused function argument: mod
(ARG001)
81-81: Unused function argument: ctx
(ARG001)
🔇 Additional comments (1)
tilelang/transform/hoist_broadcast_values.py (1)
38-59: Indices are now properly visited.The implementation correctly visits both indices (lines 43-46) and the value expression (line 47), ensuring that any Broadcast nodes in buffer access patterns are hoisted. The reversed wrapping logic (line 54) properly nests the LetStmt definitions.
| def visit_broadcast_(self, op): | ||
| if isinstance(op.value, (tir.IntImm, tir.FloatImm)): | ||
| # 1. Intercept Broadcast nodes. | ||
| # Extract the value to be hoisted into a variable. | ||
| val = self.visit_expr(op.value) | ||
| # 2. Create a new variable. | ||
| new_var = Var("broadcast_var", dtype=val.dtype) | ||
|
|
||
| # 3. Add the (variable, value) pair to the pending queue. | ||
| # Note: Do not create the LetStmt here; it must wrap the statement. | ||
| self.pending_defs.append((new_var, val)) | ||
|
|
||
| # 4. Return a new Broadcast node, using the new variable to replace the original value. | ||
| return Broadcast(new_var, op.lanes) | ||
| return op |
There was a problem hiding this comment.
Fix variable naming collisions by generating unique names.
All hoisted broadcasts use the identical variable name "broadcast_var" (line 26), causing name collisions when multiple broadcasts appear in the same statement. Each hoisted value requires a unique identifier.
🔎 Proposed fix using a counter
@tir.functor.mutator
class HoistBroadcastValuesMutator(PyStmtExprMutator):
def __init__(self):
super().__init__()
# Temporary queue: used to store variables that need to be defined within the current statement.
self.pending_defs = []
+ # Counter for generating unique variable names
+ self.var_counter = 0
def visit_broadcast_(self, op):
if isinstance(op.value, (tir.IntImm, tir.FloatImm)):
# 1. Intercept Broadcast nodes.
# Extract the value to be hoisted into a variable.
val = self.visit_expr(op.value)
# 2. Create a new variable.
- new_var = Var("broadcast_var", dtype=val.dtype)
+ new_var = Var(f"broadcast_var_{self.var_counter}", dtype=val.dtype)
+ self.var_counter += 1
# 3. Add the (variable, value) pair to the pending queue.
# Note: Do not create the LetStmt here; it must wrap the statement.
self.pending_defs.append((new_var, val))
# 4. Return a new Broadcast node, using the new variable to replace the original value.
return Broadcast(new_var, op.lanes)
return op🤖 Prompt for AI Agents
In @tilelang/transform/hoist_broadcast_values.py around lines 20-34, The hoisted
broadcast variables currently all use the literal name "broadcast_var" in
visit_broadcast_, causing collisions when multiple broadcasts are hoisted;
change the naming to generate a unique identifier per hoist (e.g., use a running
counter like self._broadcast_counter or incorporate a unique suffix) when
creating new_var, increment the counter each time you create a new Var, and keep
the rest of the logic (appending to self.pending_defs and returning
Broadcast(new_var, op.lanes)) unchanged so each hoisted value gets a unique
variable name.
There was a problem hiding this comment.
Actionable comments posted: 0
♻️ Duplicate comments (3)
tilelang/transform/hoist_broadcast_values.py (3)
20-34: CRITICAL: Variable naming collision remains unfixed.Line 26 still uses the hardcoded name
"broadcast_var"for all hoisted broadcasts, causing name collisions when multiple broadcasts appear in the same statement. This issue was previously flagged but not resolved.🔎 Proposed fix (as suggested in previous review)
Add a counter to
__init__:def __init__(self): super().__init__() # Temporary queue: used to store variables that need to be defined within the current statement. self.pending_defs = [] + # Counter for generating unique variable names + self.var_counter = 0Then update line 26 to use unique names:
- new_var = Var("broadcast_var", dtype=val.dtype) + new_var = Var(f"broadcast_var_{self.var_counter}", dtype=val.dtype) + self.var_counter += 1
36-56: MAJOR: Missing handlers for other statement types.Lines 36-37 note that "LetStmt, Evaluate, IfThenElse, AssertStmt" should be intercepted, but only
visit_buffer_store_is implemented. Broadcasts within expressions in these other statement types will not be hoisted, limiting the pass's effectiveness.🔎 Implementation guidance
Add analogous visitor methods following the same pattern as
visit_buffer_store_:def visit_evaluate_(self, op): self.pending_defs = [] new_value = self.visit_expr(op.value) new_stmt = tir.Evaluate(new_value) if self.pending_defs: for var, val in reversed(self.pending_defs): new_stmt = LetStmt(var, val, new_stmt) self.pending_defs = [] return new_stmt def visit_let_stmt_(self, op): self.pending_defs = [] new_value = self.visit_expr(op.value) new_body = self.visit_stmt(op.body) new_stmt = LetStmt(op.var, new_value, new_body) if self.pending_defs: for var, val in reversed(self.pending_defs): new_stmt = LetStmt(var, val, new_stmt) self.pending_defs = [] return new_stmtSimilar handlers needed for
visit_if_then_else_andvisit_assert_stmt_.
63-76: MINOR: Docstring example doesn't match implementation.The example shows variable names like
bv_3_14andbv_3_14_1, but the implementation uses"broadcast_var"(line 26). Once the variable naming collision is fixed, update this example to reflect the actual naming scheme used.
🧹 Nitpick comments (1)
tilelang/transform/hoist_broadcast_values.py (1)
78-78: Optional: Prefix unused parameters with underscore.Parameters
modandctxare part of theprim_func_passsignature but unused in the implementation. Consider prefixing them with underscores to indicate intentional non-use:-def pass_fn(func: PrimFunc, mod, ctx): +def pass_fn(func: PrimFunc, _mod, _ctx):
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/transform/hoist_broadcast_values.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
tilelang/transform/hoist_broadcast_values.py
🧬 Code graph analysis (1)
tilelang/transform/hoist_broadcast_values.py (2)
tilelang/language/ast/ir.py (1)
LetStmt(874-902)tilelang/language/v2/builder.py (1)
PrimFunc(690-699)
🪛 Ruff (0.14.10)
tilelang/transform/hoist_broadcast_values.py
78-78: Unused function argument: mod
(ARG001)
78-78: Unused function argument: ctx
(ARG001)
|
@codex review |
|
Codex review is not enabled for this repo. Please contact the admins of this repo to enable Codex. |
LeiWang1999
left a comment
There was a problem hiding this comment.
good to go to provide an example for this pass in testing/python/transform
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI Agents
In @testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:
- Line 30: The literal 430 assigned in the test (E[i] = 430) is not
representable in the float8_e8m0fnu format; change the broadcast value used in
test_tilelang_transform_hoist_broadcast_values.py (the E[i] = 430 assignment) to
a power-of-two value that all target dtypes can represent (e.g., 256 or 128) so
the test data is consistent across float8 variants while preserving the
kernel-source pattern being validated.
🧹 Nitpick comments (3)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py (3)
9-10: Consider a more descriptive function name.The function name
qwqis unclear. Consider renaming to something more descriptive likebroadcast_test_kernelorhoist_broadcast_kernelto better convey the test's purpose.🔎 Proposed refactor
@tilelang.jit -def qwq(dtype=torch.float8_e4m3fn): +def broadcast_test_kernel(dtype=torch.float8_e4m3fn):And update the call site:
- kernel = qwq(dtype) + kernel = broadcast_test_kernel(dtype)
40-41: Consider documenting the expected kernel source pattern.The regex pattern checks for type-consistent broadcast variable declarations but relies on specific code generation format. Consider adding a comment explaining the expected pattern (e.g.,
float broadcast_var_1 = float(value);) to make the test more maintainable if code generation changes.🔎 Proposed documentation improvement
+ # Expected pattern in kernel source: e.g., "float broadcast_var_1 = float(13.5);" + # The regex verifies that 4 immediate constants were hoisted into broadcast_var declarations matches = re.findall(r"(\w+) broadcast_var(_[0-9]+)? = \1", kernel.get_kernel_source()) assert len(matches) == 4
42-47: Consider adding output validation.The test currently validates that the kernel compiles and executes without crashing, but doesn't verify the correctness of the output values. While this may be intentional for a compiler pass test focused on code generation, adding basic output validation would strengthen confidence that the hoisted broadcast values produce correct results.
🔎 Optional enhancement for output validation
a = torch.empty((32,), device="cuda", dtype=dtype) b = torch.empty((16,), device="cuda", dtype=dtype) c = torch.empty((8,), device="cuda", dtype=dtype) d = torch.empty((4,), device="cuda", dtype=dtype) e = torch.empty((2,), device="cuda", dtype=dtype) kernel(a, b, c, d, e) + + # Validate that broadcast values were correctly written + # Note: float8 conversions may introduce precision differences + assert torch.allclose(a.float(), torch.tensor(1.0).expand(32), rtol=1e-2) + assert torch.allclose(b.float(), torch.tensor(13.5).expand(16), rtol=1e-2) + assert torch.allclose(c.float(), torch.tensor(3.14).expand(8), rtol=1e-2) + assert torch.allclose(d.float(), torch.tensor(2.72).expand(4), rtol=1e-2) + assert torch.allclose(e.float(), torch.tensor(430.0).expand(2), rtol=1e-2)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
🧬 Code graph analysis (1)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py (2)
tilelang/language/v2/dtypes.py (4)
float8_e4m3fn(341-341)float8_e5m2(355-355)float8_e8m0fnu(369-369)float16(299-299)tilelang/language/proxy.py (1)
Tensor(232-232)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py (2)
1-6: LGTM!The imports are appropriate for testing a compiler pass that generates CUDA kernels.
50-51: LGTM!Standard test entry point consistent with the testing framework.
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
Show resolved
Hide resolved
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
Show resolved
Hide resolved
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tilelang/transform/__init__.py (1)
10-10: LGTM! New pass export is correctly implemented.The import properly exports
HoistBroadcastValuesto the public API and follows the same pattern as other re-exported transformations in this file.Minor note: The static analysis tool suggests the
# noqa: F401directive may be unnecessary since F401 is not currently enabled in the linting configuration. However, keeping it maintains consistency with lines 5-9 and doesn't cause any issues.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/engine/lower.pytilelang/transform/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/engine/lower.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
tilelang/transform/__init__.py
🧬 Code graph analysis (1)
tilelang/transform/__init__.py (1)
tilelang/transform/hoist_broadcast_values.py (1)
HoistBroadcastValues(59-83)
🪛 Ruff (0.14.10)
tilelang/transform/__init__.py
10-10: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
Solve #1601
Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.