diff --git a/src/op/builtin.cc b/src/op/builtin.cc index d8122aa77..7bbc44605 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -41,6 +41,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDataRaceCheck, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTG, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTGPredicated, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableLoopUnswitching, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kLoopUnswitchingAllowNonTrivialElse, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableOutOfBoundWarning, Bool); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/src/op/builtin.h b/src/op/builtin.h index 5da3b521a..464f3d39c 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -57,6 +57,10 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kDisableLoopUnswitching = "tl.disable_loop_unswitching"; +// Allow loop unswitching even when the else-version of the loop body is +// non-trivial (has side effects). Default: false (conservative). +static constexpr const char *kLoopUnswitchingAllowNonTrivialElse = + "tl.loop_unswitching_allow_non_trivial_else"; /*! * \brief Enable lowering non-predicated global load/store to ldg/stg intrinsics diff --git a/src/transform/loop_unswitching.cc b/src/transform/loop_unswitching.cc index 7767866bb..166f4ac69 100644 --- a/src/transform/loop_unswitching.cc +++ b/src/transform/loop_unswitching.cc @@ -215,13 +215,156 @@ bool UsesLoopVarThroughLetBindings( return false; } +/*! + * \brief Check if an expression uses any variable in \p vars (directly or + * through Let bindings). + * + * This is similar to UsesLoopVarThroughLetBindings, but generalized to a set of + * variables. It is used to conservatively block unswitching on per-thread + * predicates (e.g. threadIdx.x) because later passes may insert synchronization + * calls that would become control-flow dependent after unswitching. + */ +bool UsesVarsThroughLetBindingsImpl( + const PrimExpr &expr, const std::unordered_set &vars, + const std::unordered_map *let_bindings, + std::unordered_set *visited_let_vars) { + if (vars.empty()) { + return false; + } + + // Direct use in expr + if (UsesVar(expr, [&](const VarNode *v) { return vars.count(v); })) { + return true; + } + + if (!let_bindings) { + return false; + } + + bool uses = false; + PostOrderVisit(expr, [&](const ObjectRef &obj) { + if (uses) { + return; + } + const auto *var_node = obj.as(); + if (!var_node) { + return; + } + auto it = let_bindings->find(var_node); + if (it == let_bindings->end()) { + return; + } + if (visited_let_vars && visited_let_vars->count(var_node)) { + return; + } + if (visited_let_vars) { + visited_let_vars->insert(var_node); + } + if (UsesVarsThroughLetBindingsImpl(it->second, vars, let_bindings, + visited_let_vars)) { + uses = true; + } + }); + + return uses; +} + +bool UsesVarsThroughLetBindings( + const PrimExpr &expr, const std::unordered_set &vars, + const std::unordered_map *let_bindings) { + std::unordered_set visited_let_vars; + return UsesVarsThroughLetBindingsImpl(expr, vars, let_bindings, + &visited_let_vars); +} + +/*! + * \brief Check if a statement is side-effect free (i.e. a no-op), allowing only + * pure/read-only expression evaluation. + * + * This is intentionally conservative, and is used as a profitability/safety + * guard: only unswitch when the "else version" of the loop body does not + * perform any meaningful work. This keeps the common pattern + * + * for i: if cond: S(i) + * + * while avoiding code-size blowup and control-flow complexity for + * + * for i: if cond: S1(i) else: S2(i) + * + * or when there are other side-effecting statements outside the hoisted if. + */ +bool IsSideEffectFreeStmt(const Stmt &stmt) { + if (!stmt.defined()) { + return true; + } + + if (const auto *op = stmt.as()) { + // Treat pure or read-only evaluation as no-op. + return SideEffect(op->value) <= CallEffectKind::kReadState; + } + + if (const auto *op = stmt.as()) { + for (const Stmt &s : op->seq) { + if (!IsSideEffectFreeStmt(s)) { + return false; + } + } + return true; + } + + if (const auto *op = stmt.as()) { + if (SideEffect(op->value) > CallEffectKind::kReadState) { + return false; + } + return IsSideEffectFreeStmt(op->body); + } + + if (const auto *op = stmt.as()) { + if (SideEffect(op->condition) > CallEffectKind::kReadState) { + return false; + } + if (!IsSideEffectFreeStmt(op->then_case)) { + return false; + } + if (op->else_case.defined() && + !IsSideEffectFreeStmt(op->else_case.value())) { + return false; + } + return true; + } + + if (const auto *op = stmt.as()) { + if (SideEffect(op->min) > CallEffectKind::kReadState || + SideEffect(op->extent) > CallEffectKind::kReadState) { + return false; + } + return IsSideEffectFreeStmt(op->body); + } + + // Conservatively treat all other statements as side-effecting. + return false; +} + /*! * \brief Check if a condition is loop-invariant */ -bool IsLoopInvariant(const PrimExpr &cond, const Var &loop_var, - const std::unordered_set &written_vars, - const std::unordered_map - *let_bindings = nullptr) { +bool IsLoopInvariant( + const PrimExpr &cond, const Var &loop_var, + const std::unordered_set &written_vars, + const std::unordered_map *let_bindings = nullptr, + const std::unordered_set *disallowed_vars = nullptr) { + // Check 0: disallow conditions that depend on per-thread binding vars (e.g. + // threadIdx.x). These predicates are loop-invariant, but unswitching them can + // split the execution into different code paths across threads. Later passes + // (e.g. thread sync insertion, fence proxy injection) may add synchronization + // calls outside the hoisted if, which would become control-flow dependent and + // lead to incorrect codegen. + if (disallowed_vars && !disallowed_vars->empty()) { + if (UsesVarsThroughLetBindings(cond, *disallowed_vars, let_bindings)) { + return false; + } + } + // Check 1: must not use loop variable (directly or through Let bindings) if (UsesLoopVarThroughLetBindings(cond, loop_var, let_bindings)) { return false; @@ -329,13 +472,16 @@ class HoistableIfFinder : public StmtVisitor { const IfThenElseNode *found = nullptr; const Var &loop_var; const std::unordered_set &written_vars; + const std::unordered_set *disallowed_vars; std::unordered_map let_bindings_; // Let bindings that need to be hoisted with the condition std::vector> hoisted_let_bindings; HoistableIfFinder(const Var &loop_var, - const std::unordered_set &written_vars) - : loop_var(loop_var), written_vars(written_vars) {} + const std::unordered_set &written_vars, + const std::unordered_set *disallowed_vars) + : loop_var(loop_var), written_vars(written_vars), + disallowed_vars(disallowed_vars) {} void VisitStmt_(const LetStmtNode *op) final { // Track ALL Let bindings to detect when a condition uses a variable @@ -352,8 +498,8 @@ class HoistableIfFinder : public StmtVisitor { void VisitStmt_(const IfThenElseNode *op) final { if (found) return; - if (IsLoopInvariant(op->condition, loop_var, written_vars, - &let_bindings_)) { + if (IsLoopInvariant(op->condition, loop_var, written_vars, &let_bindings_, + disallowed_vars)) { found = op; // Collect Let-bound variables used in the condition LetVarCollector collector(let_bindings_); @@ -374,7 +520,22 @@ class HoistableIfFinder : public StmtVisitor { */ class LoopUnswitcher : public StmtExprMutator { public: + explicit LoopUnswitcher(bool allow_non_trivial_else) + : allow_non_trivial_else_(allow_non_trivial_else) {} + + std::unordered_set thread_idx_vars_in_scope_; + Stmt VisitStmt_(const ForNode *op) final { + bool pushed_thread_idx = false; + if (op->thread_binding.defined()) { + String thread_tag = op->thread_binding.value()->thread_tag; + if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" || + thread_tag == "threadIdx.z") { + thread_idx_vars_in_scope_.insert(op->loop_var.get()); + pushed_thread_idx = true; + } + } + // Bottom-up: process nested structures first Stmt body = VisitStmt(op->body); @@ -383,15 +544,22 @@ class LoopUnswitcher : public StmtExprMutator { collector(body); // Find hoistable if - HoistableIfFinder finder(op->loop_var, collector.written); + HoistableIfFinder finder(op->loop_var, collector.written, + &thread_idx_vars_in_scope_); finder(body); + Stmt result; if (!finder.found) { if (body.same_as(op->body)) { - return ffi::GetRef(op); + result = ffi::GetRef(op); + } else { + result = For(op->loop_var, op->min, op->extent, op->kind, body, + op->thread_binding, op->annotations); + } + if (pushed_thread_idx) { + thread_idx_vars_in_scope_.erase(op->loop_var.get()); } - return For(op->loop_var, op->min, op->extent, op->kind, body, - op->thread_binding, op->annotations); + return result; } // Check if there are any function calls OUTSIDE the hoisted if statement. @@ -403,10 +571,15 @@ class LoopUnswitcher : public StmtExprMutator { call_checker(body); if (call_checker.has_call) { if (body.same_as(op->body)) { - return ffi::GetRef(op); + result = ffi::GetRef(op); + } else { + result = For(op->loop_var, op->min, op->extent, op->kind, body, + op->thread_binding, op->annotations); } - return For(op->loop_var, op->min, op->extent, op->kind, body, - op->thread_binding, op->annotations); + if (pushed_thread_idx) { + thread_idx_vars_in_scope_.erase(op->loop_var.get()); + } + return result; } // Unswitch: create two loop versions @@ -418,6 +591,19 @@ class LoopUnswitcher : public StmtExprMutator { Stmt else_body = IfBranchReplacer(hoisted_condition, false, finder.hoisted_let_bindings)(body); + // Only unswitch when the else-version does not do any meaningful work. + // This keeps the canonical optimization `for: if(cond) {S}` -> + // `if(cond){for:S}` while avoiding duplicating non-trivial loop bodies into + // two versions. + if (!allow_non_trivial_else_ && !IsSideEffectFreeStmt(else_body)) { + result = For(op->loop_var, op->min, op->extent, op->kind, body, + op->thread_binding, op->annotations); + if (pushed_thread_idx) { + thread_idx_vars_in_scope_.erase(op->loop_var.get()); + } + return result; + } + // Create new loop_var for else_loop to maintain SSA form Var else_loop_var(op->loop_var->name_hint, op->loop_var->dtype); else_body = Substitute(else_body, {{op->loop_var, else_loop_var}}); @@ -427,7 +613,7 @@ class LoopUnswitcher : public StmtExprMutator { For else_loop(else_loop_var, op->min, op->extent, op->kind, else_body, op->thread_binding, op->annotations); - Stmt result = IfThenElse(if_node->condition, then_loop, else_loop); + result = IfThenElse(if_node->condition, then_loop, else_loop); // Wrap with hoisted Let bindings (in reverse order so first binding is // outermost) @@ -436,14 +622,20 @@ class LoopUnswitcher : public StmtExprMutator { result = LetStmt(it->first, it->second, result); } + if (pushed_thread_idx) { + thread_idx_vars_in_scope_.erase(op->loop_var.get()); + } return result; } + +private: + bool allow_non_trivial_else_{false}; }; // --- Public API --- -Stmt ApplyLoopUnswitching(Stmt stmt) { - return LoopUnswitcher()(std::move(stmt)); +Stmt ApplyLoopUnswitching(Stmt stmt, bool allow_non_trivial_else) { + return LoopUnswitcher(allow_non_trivial_else)(std::move(stmt)); } using namespace tir::transform; @@ -455,7 +647,11 @@ tvm::transform::Pass LoopUnswitching() { if (disable_loop_unswitching) { return f; } - f.CopyOnWrite()->body = ApplyLoopUnswitching(f->body); + bool allow_non_trivial_else = + ctx->GetConfig(kLoopUnswitchingAllowNonTrivialElse, Bool(false)) + .value(); + f.CopyOnWrite()->body = + ApplyLoopUnswitching(f->body, allow_non_trivial_else); return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.LoopUnswitching", {}); diff --git a/testing/python/transform/test_tilelang_transform_loop_unswitching.py b/testing/python/transform/test_tilelang_transform_loop_unswitching.py index 212b57732..8909a3d2e 100644 --- a/testing/python/transform/test_tilelang_transform_loop_unswitching.py +++ b/testing/python/transform/test_tilelang_transform_loop_unswitching.py @@ -2,6 +2,7 @@ import tilelang as tl import tilelang.language as T import tilelang.testing +from tilelang.transform import PassConfigKey def _check(original, transformed): @@ -11,6 +12,14 @@ def _check(original, transformed): tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), map_free_vars=True) +def _check_with_config(original, transformed, config): + func = original + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + with tvm.transform.PassContext(config=config): + mod = tl.transform.LoopUnswitching()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), map_free_vars=True) + + def test_basic_hoist(): """Basic case: loop-invariant if should be hoisted outside the loop.""" @@ -41,7 +50,7 @@ def expected( def test_hoist_with_else(): - """If with else branch should be hoisted with both branches.""" + """Conservative: if with non-trivial else should NOT be hoisted.""" @T.prim_func def before( @@ -61,11 +70,11 @@ def expected( B: T.Tensor((128,), T.float32), cond: T.Tensor((1,), T.int32), ): - if cond[0] > 0: - for i in range(128): + # Should remain unchanged + for i in range(128): + if cond[0] > 0: B[i] = A[i] - else: - for i in range(128): + else: B[i] = A[i] * T.float32(2.0) _check(before, expected) @@ -124,7 +133,7 @@ def expected( def test_hoist_with_other_stmts(): - """If with other statements in loop body.""" + """Conservative: if with other side-effecting statements should NOT be hoisted.""" @T.prim_func def before( @@ -145,14 +154,11 @@ def expected( C: T.Tensor((128,), T.float32), cond: T.Tensor((1,), T.int32), ): - if cond[0] > 0: - for i in range(128): - C[i] = A[i] + # Should remain unchanged + for i in range(128): + C[i] = A[i] + if cond[0] > 0: B[i] = A[i] - else: - for i in range(128): - C[i] = A[i] - T.evaluate(0) _check(before, expected) @@ -321,7 +327,7 @@ def expected( def test_multiple_identical_conditions_with_else(): - """Multiple if-else statements with the same condition.""" + """Conservative: multiple if-else statements should NOT be hoisted.""" @T.prim_func def before( @@ -347,13 +353,15 @@ def expected( C: T.Tensor((128,), T.float32), cond: T.Tensor((1,), T.int32), ): - if cond[0] > 0: - for i in range(128): + # Should remain unchanged + for i in range(128): + if cond[0] > 0: B[i] = A[i] - C[i] = A[i] * T.float32(2.0) - else: - for i in range(128): + else: B[i] = T.float32(0) + if cond[0] > 0: + C[i] = A[i] * T.float32(2.0) + else: C[i] = T.float32(1) _check(before, expected) @@ -400,5 +408,122 @@ def get_fused_mapping_kernel(topk_idx: T.Tensor[(1,), T.int32]): get_fused_mapping_kernel.compile() +def test_no_hoist_thread_idx_predicate(): + """Do not unswitch predicates that depend on threadIdx. + + These predicates are loop-invariant, but hoisting them can split execution + across threads and break later synchronization insertion passes. + """ + + @T.prim_func + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (256,), dtype=T.int32) + B = T.match_buffer(B_ptr, (256,), dtype=T.int32) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(256, thread="threadIdx.x"): + for i in T.unroll(0, 2): + B[tx] = A[tx] + if tx == 0: + B[i] = T.int32(1) + + _check(before, before) + + +def test_hoist_with_else_when_enabled(): + """Allow hoisting if-else when explicitly enabled.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + for i in range(128): + if cond[0] > 0: + B[i] = A[i] + else: + B[i] = A[i] * T.float32(2.0) + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + if cond[0] > 0: + for i in range(128): + B[i] = A[i] + else: + for i in range(128): + B[i] = A[i] * T.float32(2.0) + + _check_with_config( + before, + expected, + config={PassConfigKey.TL_LOOP_UNSWITCHING_ALLOW_NON_TRIVIAL_ELSE: True}, + ) + + +def test_hoist_with_other_stmts_when_enabled(): + """Allow hoisting when loop contains other side effects if enabled.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + for i in range(128): + C[i] = A[i] + if cond[0] > 0: + B[i] = A[i] + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + if cond[0] > 0: + for i in range(128): + C[i] = A[i] + B[i] = A[i] + else: + for i in range(128): + C[i] = A[i] + T.evaluate(0) + + _check_with_config( + before, + expected, + config={PassConfigKey.TL_LOOP_UNSWITCHING_ALLOW_NON_TRIVIAL_ELSE: True}, + ) + + +def test_no_hoist_thread_idx_predicate_even_when_enabled(): + """The aggressive option must not unswitch per-thread predicates.""" + + @T.prim_func + def before(A_ptr: T.handle, B_ptr: T.handle): + A = T.match_buffer(A_ptr, (256,), dtype=T.int32) + B = T.match_buffer(B_ptr, (256,), dtype=T.int32) + + for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(256, thread="threadIdx.x"): + for i in T.unroll(0, 2): + B[tx] = A[tx] + if tx == 0: + B[i] = T.int32(1) + + _check_with_config( + before, + before, + config={PassConfigKey.TL_LOOP_UNSWITCHING_ALLOW_NON_TRIVIAL_ELSE: True}, + ) + + if __name__ == "__main__": tilelang.testing.main() diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 5a98c06d0..945fe73a5 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -121,6 +121,12 @@ class PassConfigKey(str, Enum): TL_DISABLE_LOOP_UNSWITCHING = "tl.disable_loop_unswitching" """Disable loop unswitching optimization. Default: False""" + TL_LOOP_UNSWITCHING_ALLOW_NON_TRIVIAL_ELSE = "tl.loop_unswitching_allow_non_trivial_else" + """Allow loop unswitching even when the else-version of the loop body has side effects. + + This is more aggressive and may increase code size. Default: False. + """ + TL_DISABLE_THREAD_STORAGE_SYNC = "tl.disable_thread_storage_sync" """Disable thread storage synchronization pass. When enabled, disables the automatic insertion of thread synchronization barriers (e.g., __syncthreads())