From b59203aa0d78acce9eb9976b707e03b176e36350 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 8 Jan 2026 14:13:15 +0800 Subject: [PATCH 1/4] [Enhancement] Implement hoisting of broadcast values in Let statements and update tests for validation --- ...lelang_transform_hoist_broadcast_values.py | 20 +++++++++ tilelang/transform/hoist_broadcast_values.py | 43 +++++++++++++++++-- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py index 5f7d23b53..e5c394f2b 100644 --- a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py +++ b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py @@ -82,5 +82,25 @@ def after(): _check(before, after) +def test_transform_hoist_let_stmt(): + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + val: T.float8x8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8) + A_shared[0:8] = val + + @T.prim_func + def after(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2) + broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4) + val: T.float8x8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8) + A_shared[0:8] = val + + _check(before, after) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/transform/hoist_broadcast_values.py b/tilelang/transform/hoist_broadcast_values.py index 35854f522..a1c06289a 100644 --- a/tilelang/transform/hoist_broadcast_values.py +++ b/tilelang/transform/hoist_broadcast_values.py @@ -16,9 +16,11 @@ def __init__(self): super().__init__() # Temporary queue: used to store variables that need to be defined within the current statement. self.pending_defs = [] + # Flag to indicate if hoist should be enabled. + self.hoist_enabled = False def visit_broadcast_(self, op): - if isinstance(op.value, (tir.IntImm, tir.FloatImm)): + if self.hoist_enabled and isinstance(op.value, (tir.IntImm, tir.FloatImm)): # 1. Intercept Broadcast nodes. # Extract the value to be hoisted into a variable. val = self.visit_expr(op.value) @@ -36,16 +38,20 @@ def visit_broadcast_(self, op): # Must intercept all Statements that might contain Expressions. # Examples: BufferStore, LetStmt, Evaluate, IfThenElse, AssertStmt. def visit_buffer_store_(self, op: BufferStore): - # 1. Clear the pending queue for the current statement context. + # 1. Enable hoist flag and clear the pending queue for the current statement context. + self.hoist_enabled = True self.pending_defs = [] # 2. Visit child nodes normally (this will trigger visit_broadcast_). new_indices = [self.visit_expr(idx) for idx in op.indices] new_stmt = BufferStore(op.buffer, self.visit_expr(op.value), new_indices) - # 3. Check if there are variables waiting to be defined. + # 3. Disable hoist flag after visiting. + self.hoist_enabled = False + + # 4. Check if there are variables waiting to be defined. if self.pending_defs: - # 4. Wrap the current statement with LetStmt. + # 5. Wrap the current statement with LetStmt. # Order: Traverse in reverse to ensure the first definition wraps the outermost layer. # Structure generated: Let my_var = val In BufferStore(...) for var, val in reversed(self.pending_defs): @@ -53,6 +59,35 @@ def visit_buffer_store_(self, op: BufferStore): # Clear the queue for the next statement. self.pending_defs = [] + print(f"new_stmt: {new_stmt}") + return new_stmt + + def visit_let_stmt_(self, op: LetStmt): + # 1. Enable hoist flag and clear the pending queue for the current statement context. + self.hoist_enabled = True + self.pending_defs = [] + + # 2. Visit the value expression (this will trigger visit_broadcast_). + new_value = self.visit_expr(op.value) + + # 3. Disable hoist flag after visiting value. + self.hoist_enabled = False + + # 4. Recursively visit the body. + new_body = self.visit_stmt(op.body) + + # 5. Create the new LetStmt. + new_stmt = LetStmt(op.var, new_value, new_body) + + # 6. Check if there are variables waiting to be defined. + if self.pending_defs: + # 7. Wrap the current statement with LetStmt. + for var, val in reversed(self.pending_defs): + new_stmt = LetStmt(var, val, new_stmt) + + # Clear the queue for the next statement. + self.pending_defs = [] + print(f"new_stmt: {new_stmt}") return new_stmt From 9215ef862ce52cb865e566db4f2de36f54c3767a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 8 Jan 2026 14:15:35 +0800 Subject: [PATCH 2/4] lint fix --- tilelang/transform/hoist_broadcast_values.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tilelang/transform/hoist_broadcast_values.py b/tilelang/transform/hoist_broadcast_values.py index a1c06289a..af8438d32 100644 --- a/tilelang/transform/hoist_broadcast_values.py +++ b/tilelang/transform/hoist_broadcast_values.py @@ -59,7 +59,6 @@ def visit_buffer_store_(self, op: BufferStore): # Clear the queue for the next statement. self.pending_defs = [] - print(f"new_stmt: {new_stmt}") return new_stmt def visit_let_stmt_(self, op: LetStmt): @@ -87,7 +86,6 @@ def visit_let_stmt_(self, op: LetStmt): # Clear the queue for the next statement. self.pending_defs = [] - print(f"new_stmt: {new_stmt}") return new_stmt From 5a59f4ff10e33e3c3e6044388e87fb44681b8ebf Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 8 Jan 2026 16:45:23 +0800 Subject: [PATCH 3/4] [Enhancement] Update hoisting logic for broadcast values in Let statements and add a new test case for nested BufferStore broadcasts --- ...lelang_transform_hoist_broadcast_values.py | 36 ++++++++++++- tilelang/transform/hoist_broadcast_values.py | 53 ++++++++++++------- 2 files changed, 67 insertions(+), 22 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py index e5c394f2b..ecab9d50f 100644 --- a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py +++ b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py @@ -87,7 +87,7 @@ def test_transform_hoist_let_stmt(): def before(): with T.Kernel(8): A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") - val: T.float8x8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8) + val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8) A_shared[0:8] = val @T.prim_func @@ -96,11 +96,43 @@ def after(): A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2) broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4) - val: T.float8x8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8) + val: T.float8_e4m3fnx8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8) A_shared[0:8] = val _check(before, after) +def test_transform_hoist_let_stmt_with_nested_bufferstore_broadcasts(): + """Test case for the bug where BufferStore in LetStmt body clears pending_defs. + + This test validates that broadcasts hoisted from a LetStmt's value expression + are preserved even when the body contains a BufferStore with additional broadcasts. + """ + + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + # LetStmt value has broadcasts + val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast( + T.float8_e4m3fn(3.4), 8) + # Body is a BufferStore with additional broadcasts + A_shared[0:8] = val + T.Broadcast(T.float8_e4m3fn(5.6), 8) + + @T.prim_func + def after(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + # Hoisted from LetStmt value + broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2) + broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4) + val: T.float8_e4m3fnx8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8) + # Hoisted from BufferStore + broadcast_var_2: T.float8_e4m3fn = T.float8_e4m3fn(5.6) + A_shared[0:8] = val + T.Broadcast(broadcast_var_2, 8) + + _check(before, after) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/transform/hoist_broadcast_values.py b/tilelang/transform/hoist_broadcast_values.py index af8438d32..91073ecbf 100644 --- a/tilelang/transform/hoist_broadcast_values.py +++ b/tilelang/transform/hoist_broadcast_values.py @@ -35,20 +35,21 @@ def visit_broadcast_(self, op): return Broadcast(new_var, op.lanes) return Broadcast(self.visit_expr(op.value), self.visit_expr(op.lanes)) - # Must intercept all Statements that might contain Expressions. - # Examples: BufferStore, LetStmt, Evaluate, IfThenElse, AssertStmt. + # Intercept statement types that might contain expressions with broadcasts. + # Currently handled: BufferStore, LetStmt. def visit_buffer_store_(self, op: BufferStore): - # 1. Enable hoist flag and clear the pending queue for the current statement context. + # 1. Save the current state to handle nested statements correctly. + saved_hoist_enabled = self.hoist_enabled + saved_pending_defs = self.pending_defs + + # 2. Enable hoist flag and clear the pending queue for the current statement context. self.hoist_enabled = True self.pending_defs = [] - # 2. Visit child nodes normally (this will trigger visit_broadcast_). + # 3. Visit child nodes normally (this will trigger visit_broadcast_). new_indices = [self.visit_expr(idx) for idx in op.indices] new_stmt = BufferStore(op.buffer, self.visit_expr(op.value), new_indices) - # 3. Disable hoist flag after visiting. - self.hoist_enabled = False - # 4. Check if there are variables waiting to be defined. if self.pending_defs: # 5. Wrap the current statement with LetStmt. @@ -57,35 +58,47 @@ def visit_buffer_store_(self, op: BufferStore): for var, val in reversed(self.pending_defs): new_stmt = LetStmt(var, val, new_stmt) - # Clear the queue for the next statement. - self.pending_defs = [] + # 6. Restore the saved state. + self.hoist_enabled = saved_hoist_enabled + self.pending_defs = saved_pending_defs + return new_stmt def visit_let_stmt_(self, op: LetStmt): - # 1. Enable hoist flag and clear the pending queue for the current statement context. + # 1. Save the current state to handle nested statements correctly. + saved_hoist_enabled = self.hoist_enabled + saved_pending_defs = self.pending_defs + + # 2. Enable hoist flag and clear the pending queue for the current statement context. self.hoist_enabled = True self.pending_defs = [] - # 2. Visit the value expression (this will trigger visit_broadcast_). + # 3. Visit the value expression (this will trigger visit_broadcast_). new_value = self.visit_expr(op.value) - # 3. Disable hoist flag after visiting value. + # 4. Capture the pending defs from the value expression before visiting body. + value_pending_defs = self.pending_defs + + # 5. Disable hoist flag and clear pending defs before visiting body. self.hoist_enabled = False + self.pending_defs = [] - # 4. Recursively visit the body. + # 6. Recursively visit the body. new_body = self.visit_stmt(op.body) - # 5. Create the new LetStmt. + # 7. Create the new LetStmt. new_stmt = LetStmt(op.var, new_value, new_body) - # 6. Check if there are variables waiting to be defined. - if self.pending_defs: - # 7. Wrap the current statement with LetStmt. - for var, val in reversed(self.pending_defs): + # 8. Check if there are variables waiting to be defined from the value expression. + if value_pending_defs: + # 9. Wrap the current statement with LetStmt. + for var, val in reversed(value_pending_defs): new_stmt = LetStmt(var, val, new_stmt) - # Clear the queue for the next statement. - self.pending_defs = [] + # 10. Restore the saved state. + self.hoist_enabled = saved_hoist_enabled + self.pending_defs = saved_pending_defs + return new_stmt From 341c1fb7d0b55502936a19e766f54a3cddb24981 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 8 Jan 2026 16:46:21 +0800 Subject: [PATCH 4/4] lint fix --- .../test_tilelang_transform_hoist_broadcast_values.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py index ecab9d50f..dd85ecaa1 100644 --- a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py +++ b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py @@ -114,8 +114,7 @@ def before(): with T.Kernel(8): A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") # LetStmt value has broadcasts - val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast( - T.float8_e4m3fn(3.4), 8) + val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8) # Body is a BufferStore with additional broadcasts A_shared[0:8] = val + T.Broadcast(T.float8_e4m3fn(5.6), 8)