Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/transform/common/constr_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ struct ConstrVisitor : public tir::StmtExprVisitor {
using StmtExprVisitor::VisitStmt_;
void VisitIfThenElseExpr(const PrimExpr cond, const PrimExpr true_value,
const PrimExpr false_value) {
// Visit the condition first without any guard, as it is always evaluated
// This ensures any buffer accesses in the condition are recorded
Base::VisitExpr(cond);
{
auto guard = MakeGuard(cond);
Base::VisitExpr(true_value);
Expand Down
30 changes: 30 additions & 0 deletions testing/python/transform/test_tilelang_transform_thread_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,5 +599,35 @@ def func():
assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync:\n{s}"


@tilelang.testing.requires_cuda
def test_sync_hoist_non_uniform_if_in_loop_with_shared_memory():
"""Test sync hoisting when non-uniform if is inside a loop with shared memory."""

@T.prim_func(private=True)
def func():
token_ids = T.alloc_buffer([128], dtype="int32", scope="shared")
result_local = T.alloc_buffer([1], dtype="float32", scope="local")
bx = T.launch_thread("blockIdx.x", 1)
tx = T.launch_thread("threadIdx.x", 128)
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
result_local[0] = T.float32(0)
for k in range(2):
# Write to shared memory
token_ids[tx] = T.int32(k - 2)
# Non-uniform if inside loop
if token_ids[tx] >= 0:
result_local[0] = T.float32(1)

mod = tvm.IRModule({"main": func})
mod = tilelang.transform.ThreadSync("shared")(mod)
s = str(mod)
assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
# Sync should be before the if inside the loop, not inside the if
sync_pos = s.index('T.tvm_storage_sync("shared")')
if_pos = s.index("if token_ids[tx] >= 0")
assert sync_pos < if_pos, f"Sync should be hoisted before non-uniform if:\n{s}"


if __name__ == "__main__":
tilelang.testing.main()
Loading