Skip to content

[Feature] Add hoist_broadcast_values pass#1606

Merged
LeiWang1999 merged 7 commits intotile-ai:mainfrom
silentCoder-dev:hoist-broadcast
Jan 7, 2026
Merged

[Feature] Add hoist_broadcast_values pass#1606
LeiWang1999 merged 7 commits intotile-ai:mainfrom
silentCoder-dev:hoist-broadcast

Conversation

@silentCoder-dev
Copy link
Collaborator

@silentCoder-dev silentCoder-dev commented Jan 5, 2026

Solve #1601

Summary by CodeRabbit

  • New Features
    • Hoists constant values out of broadcast operations to improve generated kernel clarity; the transformation is now exposed for inclusion in transform pipelines and applied in more codegen paths.
  • Tests
    • Added tests verifying broadcast hoisting across multiple numeric types (including fp8/float16) on CUDA and a regression check for a related broadcast issue.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

github-actions bot commented Jan 5, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 5, 2026

📝 Walkthrough

Walkthrough

Adds a new TIR pass HoistBroadcastValues that hoists immediate constants from Broadcast nodes into temporary Vars and inserts LetStmt definitions at BufferStore boundaries; the pass is exported and invoked in device-level lowering paths after simplification. (50 words)

Changes

Cohort / File(s) Summary
New Transform Pass
tilelang/transform/hoist_broadcast_values.py
Implements HoistBroadcastValuesMutator (subclass of PyStmtExprMutator) that replaces immediate constants in Broadcast with new Vars, accumulates (var, value) pairs, and wraps statements with LetStmt at BufferStore boundaries. Exposes HoistBroadcastValues() prim_func_pass.
Transform Exports
tilelang/transform/__init__.py
Adds public export HoistBroadcastValues by importing it from .hoist_broadcast_values.
Lowering Integration
tilelang/engine/lower.py
Applies tilelang.transform.HoistBroadcastValues() in device_codegen and device_codegen_without_compile paths after the simplification pass.
Tests — Transform
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
Adds tests exercising the pass: kernel generation for multiple dtypes, assertion on broadcast-var patterns in generated source, unit test comparing before/after TIR functions using the pass.
Tests — Issue Repro
testing/python/issue/test_tilelang_issue_1601.py
Adds a regression test asserting a specific broadcast initialization string for an fp8 kernel.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Lower as Lowering Pipeline
    participant Pass as HoistBroadcastValues Pass
    participant Mut as HoistBroadcastValuesMutator
    participant Func as PrimFunc Body
    participant BCast as Broadcast Node
    participant Store as BufferStore Node
    participant Result as Transformed Body

    Lower->>Pass: instantiate pass (post-simplify)
    Pass->>Mut: create mutator (pending_defs = [])
    Pass->>Func: apply mutator to function body
    Func->>BCast: visit Broadcast
    BCast->>Mut: visit_broadcast_ — create Var, enqueue (Var, value)
    Mut->>Func: replace immediate constant with Var
    Func->>Store: visit BufferStore
    Store->>Mut: visit_buffer_store_ — wrap stmt with LetStmt(s) for pending_defs
    Mut->>Mut: clear pending_defs
    Mut->>Result: return transformed body
    Pass->>Lower: return updated prim_func
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

🐇 I nudged constants from Broadcasts into rows of Vars,

Tucked LetStmts by stores like tiny jars.
Tiny hops in lowering, bindings neat and small,
Kernels wake tidier, I bounced through it all.

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 11.11% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: adding a new compiler pass called hoist_broadcast_values. It is concise, clear, and directly related to the primary feature being introduced.

✏️ Tip: You can configure your own custom Pre-merge Checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Fix all issues with AI Agents 🤖
In @tilelang/transform/hoist_broadcast_values.py:
- Around line 34-54: visit_buffer_store_ currently only visits op.value and
skips op.indices, so Broadcasts in buffer indices won’t be hoisted; change the
visit to traverse indices (call self.visit_expr on each index or the helper used
for expressions) before constructing new_expr/new_stmt so pending_defs from
index visits get wrapped into LetStmt(s) as you already do for value, keeping
the reversed wrapping logic and clearing pending_defs afterward; also add
analogous visitor methods for LetStmt, Evaluate, IfThenElse, and AssertStmt
(implementations should clear pending_defs, visit their contained
expressions/statements so visit_broadcast_ can enqueue defs, then wrap with
LetStmt(s) like in visit_buffer_store_), and update the module docstring to
accurately state that visit_broadcast_ hoists all Broadcast values (not just
IntImm/FloatImm).
🧹 Nitpick comments (2)
tilelang/transform/__init__.py (1)

10-10: Remove unnecessary noqa directive.

Ruff reports the noqa: F401 directive is for a non-enabled rule. Since F401 isn't enforced in this project's configuration, the directive is redundant. Note that other imports in this file (lines 5, 6, 9) also have this pattern, so this could be addressed consistently across the file in a follow-up.

tilelang/transform/hoist_broadcast_values.py (1)

76-79: Unused mod and ctx parameters are idiomatic for TVM passes.

The Ruff warnings about unused mod and ctx can be suppressed by prefixing with underscores, which is the Python convention for intentionally unused parameters.

🔎 Proposed fix
-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):
         mutator = HoistBroadcastValuesMutator()
         new_body = mutator.visit_stmt(func.body)
         return func.with_body(new_body)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32aec8a and 00d2cf4.

📒 Files selected for processing (3)
  • tilelang/engine/lower.py
  • tilelang/transform/__init__.py
  • tilelang/transform/hoist_broadcast_values.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/transform/__init__.py (1)
tilelang/transform/hoist_broadcast_values.py (1)
  • HoistBroadcastValues (57-81)
tilelang/engine/lower.py (1)
tilelang/transform/hoist_broadcast_values.py (1)
  • HoistBroadcastValues (57-81)
tilelang/transform/hoist_broadcast_values.py (2)
tilelang/language/ast/ir.py (1)
  • LetStmt (874-902)
tilelang/language/v2/builder.py (1)
  • PrimFunc (690-699)
🪛 Ruff (0.14.10)
tilelang/transform/__init__.py

10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/transform/hoist_broadcast_values.py

76-76: Unused function argument: mod

(ARG001)


76-76: Unused function argument: ctx

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
tilelang/engine/lower.py (2)

191-205: LGTM!

The HoistBroadcastValues pass is correctly integrated after Simplify() and before target-specific codegen. This ordering ensures the pass operates on simplified IR.


208-230: LGTM!

The integration in device_codegen_without_compile mirrors the placement in device_codegen, maintaining consistency across both code paths.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Fix all issues with AI Agents 🤖
In @tilelang/transform/hoist_broadcast_values.py:
- Around line 20-34: The hoisted broadcast variables currently all use the
literal name "broadcast_var" in visit_broadcast_, causing collisions when
multiple broadcasts are hoisted; change the naming to generate a unique
identifier per hoist (e.g., use a running counter like self._broadcast_counter
or incorporate a unique suffix) when creating new_var, increment the counter
each time you create a new Var, and keep the rest of the logic (appending to
self.pending_defs and returning Broadcast(new_var, op.lanes)) unchanged so each
hoisted value gets a unique variable name.
♻️ Duplicate comments (1)
tilelang/transform/hoist_broadcast_values.py (1)

62-79: Update the docstring example to match the actual implementation.

The example transformation shows variables named bv_3_14 and bv_3_14_1, but the implementation creates variables named broadcast_var (or broadcast_var_0, broadcast_var_1, etc., after fixing the naming collision). Update the example to reflect the actual variable naming scheme.

🔎 Proposed docstring correction
     Example Transformation:
     -----------------------
     Before:
         A[i] = B[i] + T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4)
 
     After:
-        bv_3_14 = 3.14
-        bv_3_14_1 = 3.14
-        A[i] = B[i] + T.Broadcast(bv_3_14, 4) + T.Broadcast(bv_3_14_1, 4)
+        broadcast_var_0 = 3.14
+        broadcast_var_1 = 3.14
+        A[i] = B[i] + T.Broadcast(broadcast_var_0, 4) + T.Broadcast(broadcast_var_1, 4)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 00d2cf4 and e8d5f07.

📒 Files selected for processing (1)
  • tilelang/transform/hoist_broadcast_values.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • tilelang/transform/hoist_broadcast_values.py
🧬 Code graph analysis (1)
tilelang/transform/hoist_broadcast_values.py (2)
tilelang/language/ast/ir.py (1)
  • LetStmt (874-902)
tilelang/language/v2/builder.py (1)
  • PrimFunc (690-699)
🪛 Ruff (0.14.10)
tilelang/transform/hoist_broadcast_values.py

81-81: Unused function argument: mod

(ARG001)


81-81: Unused function argument: ctx

(ARG001)

🔇 Additional comments (1)
tilelang/transform/hoist_broadcast_values.py (1)

38-59: Indices are now properly visited.

The implementation correctly visits both indices (lines 43-46) and the value expression (line 47), ensuring that any Broadcast nodes in buffer access patterns are hoisted. The reversed wrapping logic (line 54) properly nests the LetStmt definitions.

Comment on lines 20 to 34
def visit_broadcast_(self, op):
if isinstance(op.value, (tir.IntImm, tir.FloatImm)):
# 1. Intercept Broadcast nodes.
# Extract the value to be hoisted into a variable.
val = self.visit_expr(op.value)
# 2. Create a new variable.
new_var = Var("broadcast_var", dtype=val.dtype)

# 3. Add the (variable, value) pair to the pending queue.
# Note: Do not create the LetStmt here; it must wrap the statement.
self.pending_defs.append((new_var, val))

# 4. Return a new Broadcast node, using the new variable to replace the original value.
return Broadcast(new_var, op.lanes)
return op
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix variable naming collisions by generating unique names.

All hoisted broadcasts use the identical variable name "broadcast_var" (line 26), causing name collisions when multiple broadcasts appear in the same statement. Each hoisted value requires a unique identifier.

🔎 Proposed fix using a counter
 @tir.functor.mutator
 class HoistBroadcastValuesMutator(PyStmtExprMutator):
     def __init__(self):
         super().__init__()
         # Temporary queue: used to store variables that need to be defined within the current statement.
         self.pending_defs = []
+        # Counter for generating unique variable names
+        self.var_counter = 0
 
     def visit_broadcast_(self, op):
         if isinstance(op.value, (tir.IntImm, tir.FloatImm)):
             # 1. Intercept Broadcast nodes.
             # Extract the value to be hoisted into a variable.
             val = self.visit_expr(op.value)
             # 2. Create a new variable.
-            new_var = Var("broadcast_var", dtype=val.dtype)
+            new_var = Var(f"broadcast_var_{self.var_counter}", dtype=val.dtype)
+            self.var_counter += 1
 
             # 3. Add the (variable, value) pair to the pending queue.
             # Note: Do not create the LetStmt here; it must wrap the statement.
             self.pending_defs.append((new_var, val))
 
             # 4. Return a new Broadcast node, using the new variable to replace the original value.
             return Broadcast(new_var, op.lanes)
         return op
🤖 Prompt for AI Agents
In @tilelang/transform/hoist_broadcast_values.py around lines 20-34, The hoisted
broadcast variables currently all use the literal name "broadcast_var" in
visit_broadcast_, causing collisions when multiple broadcasts are hoisted;
change the naming to generate a unique identifier per hoist (e.g., use a running
counter like self._broadcast_counter or incorporate a unique suffix) when
creating new_var, increment the counter each time you create a new Var, and keep
the rest of the logic (appending to self.pending_defs and returning
Broadcast(new_var, op.lanes)) unchanged so each hoisted value gets a unique
variable name.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (3)
tilelang/transform/hoist_broadcast_values.py (3)

20-34: CRITICAL: Variable naming collision remains unfixed.

Line 26 still uses the hardcoded name "broadcast_var" for all hoisted broadcasts, causing name collisions when multiple broadcasts appear in the same statement. This issue was previously flagged but not resolved.

🔎 Proposed fix (as suggested in previous review)

Add a counter to __init__:

 def __init__(self):
     super().__init__()
     # Temporary queue: used to store variables that need to be defined within the current statement.
     self.pending_defs = []
+    # Counter for generating unique variable names
+    self.var_counter = 0

Then update line 26 to use unique names:

-        new_var = Var("broadcast_var", dtype=val.dtype)
+        new_var = Var(f"broadcast_var_{self.var_counter}", dtype=val.dtype)
+        self.var_counter += 1

36-56: MAJOR: Missing handlers for other statement types.

Lines 36-37 note that "LetStmt, Evaluate, IfThenElse, AssertStmt" should be intercepted, but only visit_buffer_store_ is implemented. Broadcasts within expressions in these other statement types will not be hoisted, limiting the pass's effectiveness.

🔎 Implementation guidance

Add analogous visitor methods following the same pattern as visit_buffer_store_:

def visit_evaluate_(self, op):
    self.pending_defs = []
    new_value = self.visit_expr(op.value)
    new_stmt = tir.Evaluate(new_value)
    if self.pending_defs:
        for var, val in reversed(self.pending_defs):
            new_stmt = LetStmt(var, val, new_stmt)
        self.pending_defs = []
    return new_stmt

def visit_let_stmt_(self, op):
    self.pending_defs = []
    new_value = self.visit_expr(op.value)
    new_body = self.visit_stmt(op.body)
    new_stmt = LetStmt(op.var, new_value, new_body)
    if self.pending_defs:
        for var, val in reversed(self.pending_defs):
            new_stmt = LetStmt(var, val, new_stmt)
        self.pending_defs = []
    return new_stmt

Similar handlers needed for visit_if_then_else_ and visit_assert_stmt_.


63-76: MINOR: Docstring example doesn't match implementation.

The example shows variable names like bv_3_14 and bv_3_14_1, but the implementation uses "broadcast_var" (line 26). Once the variable naming collision is fixed, update this example to reflect the actual naming scheme used.

🧹 Nitpick comments (1)
tilelang/transform/hoist_broadcast_values.py (1)

78-78: Optional: Prefix unused parameters with underscore.

Parameters mod and ctx are part of the prim_func_pass signature but unused in the implementation. Consider prefixing them with underscores to indicate intentional non-use:

-def pass_fn(func: PrimFunc, mod, ctx):
+def pass_fn(func: PrimFunc, _mod, _ctx):
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e8d5f07 and 467c996.

📒 Files selected for processing (1)
  • tilelang/transform/hoist_broadcast_values.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • tilelang/transform/hoist_broadcast_values.py
🧬 Code graph analysis (1)
tilelang/transform/hoist_broadcast_values.py (2)
tilelang/language/ast/ir.py (1)
  • LetStmt (874-902)
tilelang/language/v2/builder.py (1)
  • PrimFunc (690-699)
🪛 Ruff (0.14.10)
tilelang/transform/hoist_broadcast_values.py

78-78: Unused function argument: mod

(ARG001)


78-78: Unused function argument: ctx

(ARG001)

@LeiWang1999
Copy link
Member

@codex review

@chatgpt-codex-connector
Copy link

Codex review is not enabled for this repo. Please contact the admins of this repo to enable Codex.

Copy link
Member

@LeiWang1999 LeiWang1999 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good to go to provide an example for this pass in testing/python/transform

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI Agents
In @testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:
- Line 30: The literal 430 assigned in the test (E[i] = 430) is not
representable in the float8_e8m0fnu format; change the broadcast value used in
test_tilelang_transform_hoist_broadcast_values.py (the E[i] = 430 assignment) to
a power-of-two value that all target dtypes can represent (e.g., 256 or 128) so
the test data is consistent across float8 variants while preserving the
kernel-source pattern being validated.
🧹 Nitpick comments (3)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py (3)

9-10: Consider a more descriptive function name.

The function name qwq is unclear. Consider renaming to something more descriptive like broadcast_test_kernel or hoist_broadcast_kernel to better convey the test's purpose.

🔎 Proposed refactor
 @tilelang.jit
-def qwq(dtype=torch.float8_e4m3fn):
+def broadcast_test_kernel(dtype=torch.float8_e4m3fn):

And update the call site:

-    kernel = qwq(dtype)
+    kernel = broadcast_test_kernel(dtype)

40-41: Consider documenting the expected kernel source pattern.

The regex pattern checks for type-consistent broadcast variable declarations but relies on specific code generation format. Consider adding a comment explaining the expected pattern (e.g., float broadcast_var_1 = float(value);) to make the test more maintainable if code generation changes.

🔎 Proposed documentation improvement
+    # Expected pattern in kernel source: e.g., "float broadcast_var_1 = float(13.5);"
+    # The regex verifies that 4 immediate constants were hoisted into broadcast_var declarations
     matches = re.findall(r"(\w+) broadcast_var(_[0-9]+)? = \1", kernel.get_kernel_source())
     assert len(matches) == 4

42-47: Consider adding output validation.

The test currently validates that the kernel compiles and executes without crashing, but doesn't verify the correctness of the output values. While this may be intentional for a compiler pass test focused on code generation, adding basic output validation would strengthen confidence that the hoisted broadcast values produce correct results.

🔎 Optional enhancement for output validation
     a = torch.empty((32,), device="cuda", dtype=dtype)
     b = torch.empty((16,), device="cuda", dtype=dtype)
     c = torch.empty((8,), device="cuda", dtype=dtype)
     d = torch.empty((4,), device="cuda", dtype=dtype)
     e = torch.empty((2,), device="cuda", dtype=dtype)
     kernel(a, b, c, d, e)
+    
+    # Validate that broadcast values were correctly written
+    # Note: float8 conversions may introduce precision differences
+    assert torch.allclose(a.float(), torch.tensor(1.0).expand(32), rtol=1e-2)
+    assert torch.allclose(b.float(), torch.tensor(13.5).expand(16), rtol=1e-2)
+    assert torch.allclose(c.float(), torch.tensor(3.14).expand(8), rtol=1e-2)
+    assert torch.allclose(d.float(), torch.tensor(2.72).expand(4), rtol=1e-2)
+    assert torch.allclose(e.float(), torch.tensor(430.0).expand(2), rtol=1e-2)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 467c996 and 00c8056.

📒 Files selected for processing (1)
  • testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
🧬 Code graph analysis (1)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py (2)
tilelang/language/v2/dtypes.py (4)
  • float8_e4m3fn (341-341)
  • float8_e5m2 (355-355)
  • float8_e8m0fnu (369-369)
  • float16 (299-299)
tilelang/language/proxy.py (1)
  • Tensor (232-232)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py (2)

1-6: LGTM!

The imports are appropriate for testing a compiler pass that generates CUDA kernels.


50-51: LGTM!

Standard test entry point consistent with the testing framework.

@kurisu6912 kurisu6912 linked an issue Jan 6, 2026 that may be closed by this pull request
1 task
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tilelang/transform/__init__.py (1)

10-10: LGTM! New pass export is correctly implemented.

The import properly exports HoistBroadcastValues to the public API and follows the same pattern as other re-exported transformations in this file.

Minor note: The static analysis tool suggests the # noqa: F401 directive may be unnecessary since F401 is not currently enabled in the linting configuration. However, keeping it maintains consistency with lines 5-9 and doesn't cause any issues.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 83cd936 and 3473885.

📒 Files selected for processing (2)
  • tilelang/engine/lower.py
  • tilelang/transform/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/engine/lower.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • tilelang/transform/__init__.py
🧬 Code graph analysis (1)
tilelang/transform/__init__.py (1)
tilelang/transform/hoist_broadcast_values.py (1)
  • HoistBroadcastValues (59-83)
🪛 Ruff (0.14.10)
tilelang/transform/__init__.py

10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Optimize broadcast initialization in vectorized assignments

4 participants