-
Notifications
You must be signed in to change notification settings - Fork 438
[Feature] Add hoist_broadcast_values pass #1606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+197
−0
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
00d2cf4
add hoist_broadcast_values pass
silentCoder-dev e8d5f07
only hoist intImm or floatImm
silentCoder-dev 467c996
refactor
silentCoder-dev 00c8056
add test for hoist_broadcast_values
silentCoder-dev 1cc218f
add a test for issue 1601
silentCoder-dev 83cd936
add transform test for hoist_broadcast_value
silentCoder-dev 3473885
Merge branch 'main' of https://github.com/tile-ai/tilelang into hoist…
kurisu6912 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| import tilelang | ||
| import tilelang.testing | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
def test_issue_1601():
    """Regression test for issue #1601.

    A vectorized store of an immediate 0 into an fp8_e4m3 buffer must have
    its broadcast constant hoisted into a scalar `broadcast_var` in the
    generated CUDA source.
    """

    @tilelang.jit
    def build():
        @T.prim_func
        def main(
            A: T.Tensor((8,), T.float8_e4m3fn),
        ):
            with T.Kernel(1, threads=32):
                for i in T.vectorized(8):
                    A[i] = 0

        return main

    source = build().get_kernel_source()
    expected = "fp8_e4_t broadcast_var = fp8_e4_t(0x0p+0f/*0.000000e+00*/);"
    assert expected in source
|
|
||
|
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    tilelang.testing.main()
86 changes: 86 additions & 0 deletions
86
testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| import tilelang | ||
| import tilelang.language as T | ||
| import torch | ||
| import re | ||
| import pytest | ||
| import tilelang.testing | ||
| from tilelang import tvm as tvm | ||
| import tilelang as tl | ||
| from tilelang.utils.target import determine_target | ||
|
|
||
|
|
||
@tilelang.jit
def qwq(dtype=torch.float8_e4m3fn):
    """Build a kernel containing one runtime-variable vectorized store and
    four immediate-constant vectorized stores.

    Only the four immediate constants are expected to be hoisted into
    scalar broadcast variables by the HoistBroadcastValues pass.
    """

    @T.prim_func
    def main(
        A: T.Tensor((32,), dtype),
        B: T.Tensor((16,), dtype),
        C: T.Tensor((8,), dtype),
        D: T.Tensor((4,), dtype),
        E: T.Tensor((2,), dtype),
    ):
        with T.Kernel(1, threads=32):
            # Runtime variable: broadcasting it is NOT subject to hoisting.
            var = T.alloc_var(dtype, 1.0)
            for j in T.vectorized(32):
                A[j] = var
            # Immediate constants: each of the four stores below should
            # produce one hoisted broadcast variable.
            for j in T.vectorized(16):
                B[j] = 13.5
            for j in T.vectorized(8):
                C[j] = 3.14
            for j in T.vectorized(4):
                D[j] = 2.72
            for j in T.vectorized(2):
                E[j] = 430

    return main
|
|
||
|
|
||
@tilelang.testing.requires_cuda
@pytest.mark.parametrize(
    "dtype",
    [torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e8m0fnu, torch.float16])
def test_hoist_broadcast(dtype):
    """Compile `qwq` and check exactly four broadcast constants were hoisted,
    then smoke-run the kernel on CUDA."""
    kernel = qwq(dtype)
    source = kernel.get_kernel_source()
    print(source)
    # Each hoisted constant appears as `<type> broadcast_var[_N] = <type>(...)`.
    hoisted = re.findall(r"(\w+) broadcast_var(_[0-9]+)? = \1", source)
    assert len(hoisted) == 4
    # Allocate one output tensor per kernel argument and execute once.
    outputs = [
        torch.empty((size,), device="cuda", dtype=dtype)
        for size in (32, 16, 8, 4, 2)
    ]
    kernel(*outputs)
|
|
||
|
|
||
# Resolve the concrete compilation target once for the structural checks below.
_target_name = determine_target("auto")
auto_target = tvm.target.Target(_target_name)
|
|
||
|
|
||
def _check(original, transformed):
    """Assert that running HoistBroadcastValues on `original` yields IR
    structurally equal to `transformed`."""

    def _to_module(func):
        # Wrap the PrimFunc in a module and bind the auto-detected target.
        mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
        return tvm.tir.transform.BindTarget(auto_target)(mod)

    actual = tl.transform.HoistBroadcastValues()(_to_module(original))
    expected = _to_module(transformed)
    tvm.ir.assert_structural_equal(actual["main"], expected["main"], True)
|
|
||
|
|
||
def test_transform_hoist():
    """IR-level check: two constant Broadcasts inside one BufferStore are each
    hoisted into its own Let-bound scalar (`broadcast_var`, `broadcast_var_1`)."""
    fp8 = T.float8_e4m3fn

    @T.prim_func
    def before():
        with T.Kernel(8):
            A_shared = T.decl_buffer((256), fp8, scope="shared.dyn")
            A_shared[0:8] = T.Broadcast(fp8(1.2), 8) + T.Broadcast(fp8(3.4), 8)

    @T.prim_func
    def after():
        with T.Kernel(8):
            A_shared = T.decl_buffer((256), fp8, scope="shared.dyn")
            broadcast_var: T.float8_e4m3fn = fp8(1.2)
            broadcast_var_1: T.float8_e4m3fn = fp8(3.4)
            A_shared[0:8] = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8)

    _check(before, after)
|
|
||
|
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    tilelang.testing.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| from tvm import tir | ||
| from tvm.tir import ( | ||
| BufferStore, | ||
| LetStmt, | ||
| Broadcast, | ||
| Var, | ||
| PrimFunc, | ||
| PyStmtExprMutator, | ||
| ) | ||
| from tvm.tir.transform import prim_func_pass | ||
|
|
||
|
|
||
@tir.functor.mutator
class HoistBroadcastValuesMutator(PyStmtExprMutator):
    """Rewrites `Broadcast(<imm>, lanes)` into `Broadcast(<var>, lanes)`,
    collecting (var, imm) pairs so the enclosing statement can be wrapped
    in LetStmt bindings."""

    def __init__(self):
        super().__init__()
        # (Var, PrimExpr) pairs collected while visiting the expressions of
        # the statement currently being rewritten; flushed per statement.
        self._pending = []

    def visit_broadcast_(self, op):
        value = self.visit_expr(op.value)
        if not isinstance(op.value, (tir.IntImm, tir.FloatImm)):
            # Non-immediate payloads are left in place untouched.
            return Broadcast(value, self.visit_expr(op.lanes))
        # Bind the immediate to a fresh scalar variable. The LetStmt itself
        # cannot be created here (this is expression context); it is emitted
        # later, wrapped around the statement that owns this expression.
        hoisted = Var("broadcast_var", dtype=value.dtype)
        self._pending.append((hoisted, value))
        return Broadcast(hoisted, op.lanes)

    # NOTE(review): only BufferStore flushes the pending bindings. A Broadcast
    # reached from another statement kind (Evaluate, IfThenElse, AssertStmt,
    # ...) would leave its variable unbound — confirm those cannot carry
    # immediate broadcasts here.
    def visit_buffer_store_(self, op: BufferStore):
        # Start with a clean queue for this statement's context.
        self._pending = []
        # Visiting children triggers visit_broadcast_ and fills the queue.
        new_indices = [self.visit_expr(idx) for idx in op.indices]
        result = BufferStore(op.buffer, self.visit_expr(op.value), new_indices)
        # Wrap in reverse so the first collected binding ends up outermost:
        #   Let v0 = ... in Let v1 = ... in BufferStore(...)
        for var, val in reversed(self._pending):
            result = LetStmt(var, val, result)
        # Reset for the next statement.
        self._pending = []
        return result
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def HoistBroadcastValues():
    """Create the HoistBroadcastValues TIR pass.

    The pass scans the TIR for Broadcast operations whose payload is an
    immediate constant (IntImm or FloatImm) and extracts each constant into
    a variable bound by a LetStmt placed immediately around the statement in
    which the broadcast occurs. Every hoisted variable is named
    ``broadcast_var`` (deduplicated with ``_N`` suffixes when printed), so
    the example below reflects the actual output::

        Before:
            A[0:4] = T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4)

        After:
            broadcast_var = 3.14
            broadcast_var_1 = 3.14
            A[0:4] = T.Broadcast(broadcast_var, 4) + T.Broadcast(broadcast_var_1, 4)

    Returns:
        A ``prim_func_pass`` (opt_level 0) that applies the rewrite to each
        PrimFunc body.
    """

    def pass_fn(func: PrimFunc, mod, ctx):
        # A fresh mutator per function keeps pending bindings isolated
        # between PrimFuncs.
        mutator = HoistBroadcastValuesMutator()
        return func.with_body(mutator.visit_stmt(func.body))

    return prim_func_pass(pass_fn, opt_level=0)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.