diff --git a/src/transform/common/constr_visitor.h b/src/transform/common/constr_visitor.h
index af7ae36d6..a87f7313d 100644
--- a/src/transform/common/constr_visitor.h
+++ b/src/transform/common/constr_visitor.h
@@ -170,6 +170,9 @@ struct ConstrVisitor : public tir::StmtExprVisitor {
   using StmtExprVisitor::VisitStmt_;
   void VisitIfThenElseExpr(const PrimExpr cond, const PrimExpr true_value,
                            const PrimExpr false_value) {
+    // Visit the condition first, outside any guard, because it is always
+    // evaluated; this records any buffer accesses made by the condition.
+    Base::VisitExpr(cond);
     {
       auto guard = MakeGuard(cond);
       Base::VisitExpr(true_value);
diff --git a/testing/python/transform/test_tilelang_transform_thread_sync.py b/testing/python/transform/test_tilelang_transform_thread_sync.py
index 8b2901571..9d20de103 100644
--- a/testing/python/transform/test_tilelang_transform_thread_sync.py
+++ b/testing/python/transform/test_tilelang_transform_thread_sync.py
@@ -599,5 +599,35 @@ def func():
     assert 'T.tvm_storage_sync("shared")' not in s, f"Unexpected sync:\n{s}"
 
 
+@tilelang.testing.requires_cuda
+def test_sync_hoist_non_uniform_if_in_loop_with_shared_memory():
+    """Check that ThreadSync("shared") puts the sync before a non-uniform if in a loop."""
+
+    @T.prim_func(private=True)
+    def func():
+        token_ids = T.alloc_buffer([128], dtype="int32", scope="shared")
+        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
+        bx = T.launch_thread("blockIdx.x", 1)
+        tx = T.launch_thread("threadIdx.x", 128)
+        ty = T.launch_thread("threadIdx.y", 1)
+        tz = T.launch_thread("threadIdx.z", 1)
+        result_local[0] = T.float32(0)
+        for k in range(2):
+            # Each thread stores into its own slot (index tx) of the shared buffer
+            token_ids[tx] = T.int32(k - 2)
+            # The if condition reads the shared buffer per thread, inside the loop
+            if token_ids[tx] >= 0:
+                result_local[0] = T.float32(1)
+
+    mod = tvm.IRModule({"main": func})
+    mod = tilelang.transform.ThreadSync("shared")(mod)
+    s = str(mod)
+    assert 'T.tvm_storage_sync("shared")' in s, f"Expected sync:\n{s}"
+    # The first sync in the printed module must come before the if, not inside it
+    sync_pos = s.index('T.tvm_storage_sync("shared")')
+    if_pos = s.index("if token_ids[tx] >= 0")
+    assert sync_pos < if_pos, f"Sync should be hoisted before non-uniform if:\n{s}"
+
+
 if __name__ == "__main__":
     tilelang.testing.main()