diff --git a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py index 5f7d23b53..dd85ecaa1 100644 --- a/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py +++ b/testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py @@ -82,5 +82,56 @@ def after(): _check(before, after) +def test_transform_hoist_let_stmt(): + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8) + A_shared[0:8] = val + + @T.prim_func + def after(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2) + broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4) + val: T.float8_e4m3fnx8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8) + A_shared[0:8] = val + + _check(before, after) + + +def test_transform_hoist_let_stmt_with_nested_bufferstore_broadcasts(): + """Test case for the bug where BufferStore in LetStmt body clears pending_defs. + + This test validates that broadcasts hoisted from a LetStmt's value expression + are preserved even when the body contains a BufferStore with additional broadcasts. + """ + + @T.prim_func + def before(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + # LetStmt value has broadcasts + val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8) + # Body is a BufferStore with additional broadcasts + A_shared[0:8] = val + T.Broadcast(T.float8_e4m3fn(5.6), 8) + + @T.prim_func + def after(): + with T.Kernel(8): + A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn") + # Hoisted from LetStmt value + broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2) + broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4) + val: T.float8_e4m3fnx8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8) + # Hoisted from BufferStore + broadcast_var_2: T.float8_e4m3fn = T.float8_e4m3fn(5.6) + A_shared[0:8] = val + T.Broadcast(broadcast_var_2, 8) + + _check(before, after) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/transform/hoist_broadcast_values.py b/tilelang/transform/hoist_broadcast_values.py index 35854f522..91073ecbf 100644 --- a/tilelang/transform/hoist_broadcast_values.py +++ b/tilelang/transform/hoist_broadcast_values.py @@ -16,9 +16,11 @@ def __init__(self): super().__init__() # Temporary queue: used to store variables that need to be defined within the current statement. self.pending_defs = [] + # Flag to indicate if hoist should be enabled. + self.hoist_enabled = False def visit_broadcast_(self, op): - if isinstance(op.value, (tir.IntImm, tir.FloatImm)): + if self.hoist_enabled and isinstance(op.value, (tir.IntImm, tir.FloatImm)): # 1. Intercept Broadcast nodes. # Extract the value to be hoisted into a variable. val = self.visit_expr(op.value) @@ -33,26 +35,70 @@ def visit_broadcast_(self, op): return Broadcast(new_var, op.lanes) return Broadcast(self.visit_expr(op.value), self.visit_expr(op.lanes)) - # Must intercept all Statements that might contain Expressions. - # Examples: BufferStore, LetStmt, Evaluate, IfThenElse, AssertStmt. + # Intercept statement types that might contain expressions with broadcasts. + # Currently handled: BufferStore, LetStmt. def visit_buffer_store_(self, op: BufferStore): - # 1. Clear the pending queue for the current statement context. + # 1. Save the current state to handle nested statements correctly. + saved_hoist_enabled = self.hoist_enabled + saved_pending_defs = self.pending_defs + + # 2. Enable hoist flag and clear the pending queue for the current statement context. + self.hoist_enabled = True self.pending_defs = [] - # 2. Visit child nodes normally (this will trigger visit_broadcast_). + # 3. Visit child nodes normally (this will trigger visit_broadcast_). new_indices = [self.visit_expr(idx) for idx in op.indices] new_stmt = BufferStore(op.buffer, self.visit_expr(op.value), new_indices) - # 3. Check if there are variables waiting to be defined. + # 4. Check if there are variables waiting to be defined. if self.pending_defs: - # 4. Wrap the current statement with LetStmt. + # 5. Wrap the current statement with LetStmt. # Order: Traverse in reverse to ensure the first definition wraps the outermost layer. # Structure generated: Let my_var = val In BufferStore(...) for var, val in reversed(self.pending_defs): new_stmt = LetStmt(var, val, new_stmt) - # Clear the queue for the next statement. - self.pending_defs = [] + # 6. Restore the saved state. + self.hoist_enabled = saved_hoist_enabled + self.pending_defs = saved_pending_defs + + return new_stmt + + def visit_let_stmt_(self, op: LetStmt): + # 1. Save the current state to handle nested statements correctly. + saved_hoist_enabled = self.hoist_enabled + saved_pending_defs = self.pending_defs + + # 2. Enable hoist flag and clear the pending queue for the current statement context. + self.hoist_enabled = True + self.pending_defs = [] + + # 3. Visit the value expression (this will trigger visit_broadcast_). + new_value = self.visit_expr(op.value) + + # 4. Capture the pending defs from the value expression before visiting body. + value_pending_defs = self.pending_defs + + # 5. Disable hoist flag and clear pending defs before visiting body. + self.hoist_enabled = False + self.pending_defs = [] + + # 6. Recursively visit the body. + new_body = self.visit_stmt(op.body) + + # 7. Create the new LetStmt. + new_stmt = LetStmt(op.var, new_value, new_body) + + # 8. Check if there are variables waiting to be defined from the value expression. + if value_pending_defs: + # 9. Wrap the current statement with LetStmt. + for var, val in reversed(value_pending_defs): + new_stmt = LetStmt(var, val, new_stmt) + + # 10. Restore the saved state. + self.hoist_enabled = saved_hoist_enabled + self.pending_defs = saved_pending_defs + return new_stmt