Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,56 @@ def after():
_check(before, after)


def test_transform_hoist_let_stmt():
@T.prim_func
def before():
with T.Kernel(8):
A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn")
val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8)
A_shared[0:8] = val

@T.prim_func
def after():
with T.Kernel(8):
A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn")
broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2)
broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4)
val: T.float8_e4m3fnx8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8)
A_shared[0:8] = val

_check(before, after)


def test_transform_hoist_let_stmt_with_nested_bufferstore_broadcasts():
"""Test case for the bug where BufferStore in LetStmt body clears pending_defs.

This test validates that broadcasts hoisted from a LetStmt's value expression
are preserved even when the body contains a BufferStore with additional broadcasts.
"""

@T.prim_func
def before():
with T.Kernel(8):
A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn")
# LetStmt value has broadcasts
val: T.float8_e4m3fnx8 = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8)
# Body is a BufferStore with additional broadcasts
A_shared[0:8] = val + T.Broadcast(T.float8_e4m3fn(5.6), 8)

@T.prim_func
def after():
with T.Kernel(8):
A_shared = T.decl_buffer((256), T.float8_e4m3fn, scope="shared.dyn")
# Hoisted from LetStmt value
broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2)
broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4)
val: T.float8_e4m3fnx8 = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8)
# Hoisted from BufferStore
broadcast_var_2: T.float8_e4m3fn = T.float8_e4m3fn(5.6)
A_shared[0:8] = val + T.Broadcast(broadcast_var_2, 8)

_check(before, after)


if __name__ == "__main__":
tilelang.testing.main()
64 changes: 55 additions & 9 deletions tilelang/transform/hoist_broadcast_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def __init__(self):
super().__init__()
# Temporary queue: used to store variables that need to be defined within the current statement.
self.pending_defs = []
# Flag to indicate if hoist should be enabled.
self.hoist_enabled = False

def visit_broadcast_(self, op):
if isinstance(op.value, (tir.IntImm, tir.FloatImm)):
if self.hoist_enabled and isinstance(op.value, (tir.IntImm, tir.FloatImm)):
# 1. Intercept Broadcast nodes.
# Extract the value to be hoisted into a variable.
val = self.visit_expr(op.value)
Expand All @@ -33,26 +35,70 @@ def visit_broadcast_(self, op):
return Broadcast(new_var, op.lanes)
return Broadcast(self.visit_expr(op.value), self.visit_expr(op.lanes))

# Must intercept all Statements that might contain Expressions.
# Examples: BufferStore, LetStmt, Evaluate, IfThenElse, AssertStmt.
# Intercept statement types that might contain expressions with broadcasts.
# Currently handled: BufferStore, LetStmt.
def visit_buffer_store_(self, op: BufferStore):
# 1. Clear the pending queue for the current statement context.
# 1. Save the current state to handle nested statements correctly.
saved_hoist_enabled = self.hoist_enabled
saved_pending_defs = self.pending_defs

# 2. Enable hoist flag and clear the pending queue for the current statement context.
self.hoist_enabled = True
self.pending_defs = []

# 2. Visit child nodes normally (this will trigger visit_broadcast_).
# 3. Visit child nodes normally (this will trigger visit_broadcast_).
new_indices = [self.visit_expr(idx) for idx in op.indices]
new_stmt = BufferStore(op.buffer, self.visit_expr(op.value), new_indices)

# 3. Check if there are variables waiting to be defined.
# 4. Check if there are variables waiting to be defined.
if self.pending_defs:
# 4. Wrap the current statement with LetStmt.
# 5. Wrap the current statement with LetStmt.
# Order: Traverse in reverse to ensure the first definition wraps the outermost layer.
# Structure generated: Let my_var = val In BufferStore(...)
for var, val in reversed(self.pending_defs):
new_stmt = LetStmt(var, val, new_stmt)

# Clear the queue for the next statement.
self.pending_defs = []
# 6. Restore the saved state.
self.hoist_enabled = saved_hoist_enabled
self.pending_defs = saved_pending_defs

return new_stmt

def visit_let_stmt_(self, op: LetStmt):
# 1. Save the current state to handle nested statements correctly.
saved_hoist_enabled = self.hoist_enabled
saved_pending_defs = self.pending_defs

# 2. Enable hoist flag and clear the pending queue for the current statement context.
self.hoist_enabled = True
self.pending_defs = []

# 3. Visit the value expression (this will trigger visit_broadcast_).
new_value = self.visit_expr(op.value)

# 4. Capture the pending defs from the value expression before visiting body.
value_pending_defs = self.pending_defs

# 5. Disable hoist flag and clear pending defs before visiting body.
self.hoist_enabled = False
self.pending_defs = []

# 6. Recursively visit the body.
new_body = self.visit_stmt(op.body)

# 7. Create the new LetStmt.
new_stmt = LetStmt(op.var, new_value, new_body)

# 8. Check if there are variables waiting to be defined from the value expression.
if value_pending_defs:
# 9. Wrap the current statement with LetStmt.
for var, val in reversed(value_pending_defs):
new_stmt = LetStmt(var, val, new_stmt)

# 10. Restore the saved state.
self.hoist_enabled = saved_hoist_enabled
self.pending_defs = saved_pending_defs

return new_stmt


Expand Down
Loading