diff --git a/src/transform/loop_unswitching.cc b/src/transform/loop_unswitching.cc index 0adbc5ac7..aab8f7fdc 100644 --- a/src/transform/loop_unswitching.cc +++ b/src/transform/loop_unswitching.cc @@ -144,7 +144,8 @@ class CallNodeChecker : public ExprVisitor { }; /*! - * \brief Check if a statement contains any CallNode, excluding a specific If + * \brief Check if a statement contains any CallNode, excluding matching If + * nodes * * Loop unswitching is unsafe when there are function calls OUTSIDE the * hoisted if statement, because those calls (originally executed by all @@ -153,15 +154,19 @@ class CallNodeChecker : public ExprVisitor { * * Calls INSIDE the if are safe because they were already conditionally * executed before unswitching. + * + * Since we replace ALL if statements with matching conditions, we need to + * exclude all such if statements when checking for calls. */ class CallCheckerExcludingIf : public StmtExprVisitor { public: bool has_call = false; - const IfThenElseNode *excluded_if = nullptr; + PrimExpr excluded_condition; void VisitStmt_(const IfThenElseNode *op) final { - if (op == excluded_if) { - // Skip the interior of the excluded if statement + // Skip the interior of any if statement with matching condition + if (excluded_condition.defined() && + StructuralEqual()(op->condition, excluded_condition)) { return; } StmtExprVisitor::VisitStmt_(op); @@ -237,18 +242,35 @@ bool IsLoopInvariant(const PrimExpr &cond, const Var &loop_var, } /*! - * \brief Replace a specific if node with its then/else branch + * \brief Replace if nodes with matching condition with their then/else branch + * + * When hoisting a condition out of a loop, we need to replace ALL if statements + * with the same condition, not just the first one found. This ensures that + * in the then-branch all matching conditions are replaced with their then-case, + * and in the else-branch all matching conditions are replaced with their + * else-case. + * + * Also removes LetStmts for variables that have been hoisted, since they are + * now redundant (the variable is already bound outside the loop). */ class IfBranchReplacer : public StmtExprMutator { public: - const IfThenElseNode *target; + PrimExpr hoisted_condition; bool take_then; - - IfBranchReplacer(const IfThenElseNode *target, bool take_then) - : target(target), take_then(take_then) {} + std::unordered_set hoisted_vars; + + IfBranchReplacer( + const PrimExpr &condition, bool take_then, + const std::vector> &hoisted_let_bindings) + : hoisted_condition(condition), take_then(take_then) { + for (const auto &binding : hoisted_let_bindings) { + hoisted_vars.insert(binding.first.get()); + } + } Stmt VisitStmt_(const IfThenElseNode *op) final { - if (op == target) { + // Replace if the condition is structurally equal to the hoisted condition + if (StructuralEqual()(op->condition, hoisted_condition)) { if (take_then) { return VisitStmt(op->then_case); } else { @@ -258,6 +280,43 @@ class IfBranchReplacer : public StmtExprMutator { } return StmtExprMutator::VisitStmt_(op); } + + Stmt VisitStmt_(const LetStmtNode *op) final { + // Remove LetStmts for hoisted variables (they are now bound outside the + // loop) + if (hoisted_vars.count(op->var.get())) { + return VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } +}; + +/*! + * \brief Collect Let-bound variables used in an expression + */ +class LetVarCollector : public ExprVisitor { +public: + std::vector> used_let_bindings; + const std::unordered_map &let_bindings; + std::unordered_set visited; + + explicit LetVarCollector( + const std::unordered_map &bindings) + : let_bindings(bindings) {} + + void VisitExpr_(const VarNode *op) final { + if (visited.count(op)) + return; + auto it = let_bindings.find(op); + if (it != let_bindings.end()) { + visited.insert(op); + // First recursively collect Let-bound vars used in this binding's value + VisitExpr(it->second); + // Then add this binding (so dependencies come first) + used_let_bindings.push_back( + std::make_pair(ffi::GetRef(op), it->second)); + } + } }; /*! @@ -271,6 +330,8 @@ class HoistableIfFinder : public StmtVisitor { const Var &loop_var; const std::unordered_set &written_vars; std::unordered_map let_bindings_; + // Let bindings that need to be hoisted with the condition + std::vector> hoisted_let_bindings; HoistableIfFinder(const Var &loop_var, const std::unordered_set &written_vars) @@ -294,6 +355,10 @@ class HoistableIfFinder : public StmtVisitor { if (IsLoopInvariant(op->condition, loop_var, written_vars, &let_bindings_)) { found = op; + // Collect Let-bound variables used in the condition + LetVarCollector collector(let_bindings_); + collector(op->condition); + hoisted_let_bindings = std::move(collector.used_let_bindings); return; } StmtVisitor::VisitStmt_(op); @@ -334,7 +399,7 @@ class LoopUnswitcher : public StmtExprMutator { // would split them into different code paths, breaking synchronization. // Calls inside the if are already conditionally executed, so they're safe. CallCheckerExcludingIf call_checker; - call_checker.excluded_if = finder.found; + call_checker.excluded_condition = finder.found->condition; call_checker(body); if (call_checker.has_call) { if (body.same_as(op->body)) { @@ -346,9 +411,12 @@ class LoopUnswitcher : public StmtExprMutator { // Unswitch: create two loop versions const IfThenElseNode *if_node = finder.found; + PrimExpr hoisted_condition = if_node->condition; - Stmt then_body = IfBranchReplacer(if_node, true)(body); - Stmt else_body = IfBranchReplacer(if_node, false)(body); + Stmt then_body = IfBranchReplacer(hoisted_condition, true, + finder.hoisted_let_bindings)(body); + Stmt else_body = IfBranchReplacer(hoisted_condition, false, + finder.hoisted_let_bindings)(body); // Create new loop_var for else_loop to maintain SSA form Var else_loop_var(op->loop_var->name_hint, op->loop_var->dtype); @@ -359,7 +427,16 @@ class LoopUnswitcher : public StmtExprMutator { For else_loop(else_loop_var, op->min, op->extent, op->kind, else_body, op->thread_binding, op->annotations); - return IfThenElse(if_node->condition, then_loop, else_loop); + Stmt result = IfThenElse(if_node->condition, then_loop, else_loop); + + // Wrap with hoisted Let bindings (in reverse order so first binding is + // outermost) + for (auto it = finder.hoisted_let_bindings.rbegin(); + it != finder.hoisted_let_bindings.rend(); ++it) { + result = LetStmt(it->first, it->second, result); + } + + return result; } }; diff --git a/testing/python/transform/test_tilelang_transform_loop_unswitching.py b/testing/python/transform/test_tilelang_transform_loop_unswitching.py index 8a4bc6214..91ff355b9 100644 --- a/testing/python/transform/test_tilelang_transform_loop_unswitching.py +++ b/testing/python/transform/test_tilelang_transform_loop_unswitching.py @@ -219,5 +219,172 @@ def expected( _check(before, expected) +def test_hoist_let_bound_variable(): + """If condition uses a Let-bound variable, both should be hoisted together.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((1,), T.float32), + ): + for i in range(128): + pos = C[0] + if pos >= T.float32(0): + B[i] = A[i] + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((1,), T.float32), + ): + # Let binding is hoisted before the if, redundant inner LetStmt is removed + pos = C[0] + if pos >= T.float32(0): + for i in range(128): + B[i] = A[i] + else: + for _i in range(128): + T.evaluate(0) + + _check(before, expected) + + +def test_hoist_multiple_let_bound_variables(): + """If condition uses multiple Let-bound variables, all should be hoisted.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((2,), T.float32), + ): + for i in range(128): + x = C[0] + y = C[1] + if x + y >= T.float32(0): + B[i] = A[i] + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((2,), T.float32), + ): + # Let bindings are hoisted before the if, redundant inner LetStmts are removed + x = C[0] + y = C[1] + if x + y >= T.float32(0): + for i in range(128): + B[i] = A[i] + else: + for _i in range(128): + T.evaluate(0) + + _check(before, expected) + + +def test_multiple_identical_conditions(): + """Multiple if statements with the same condition should all be replaced.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + for i in range(128): + if cond[0] > 0: + B[i] = A[i] + if cond[0] > 0: + C[i] = A[i] * T.float32(2.0) + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + if cond[0] > 0: + for i in range(128): + B[i] = A[i] + C[i] = A[i] * T.float32(2.0) + else: + for _i in range(128): + T.evaluate(0) + T.evaluate(0) + + _check(before, expected) + + +def test_multiple_identical_conditions_with_else(): + """Multiple if-else statements with the same condition.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + for i in range(128): + if cond[0] > 0: + B[i] = A[i] + else: + B[i] = T.float32(0) + if cond[0] > 0: + C[i] = A[i] * T.float32(2.0) + else: + C[i] = T.float32(1) + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + C: T.Tensor((128,), T.float32), + cond: T.Tensor((1,), T.int32), + ): + if cond[0] > 0: + for i in range(128): + B[i] = A[i] + C[i] = A[i] * T.float32(2.0) + else: + for i in range(128): + B[i] = T.float32(0) + C[i] = T.float32(1) + + _check(before, expected) + + +def test_no_hoist_let_bound_loop_variant(): + """Let-bound variable depends on loop var, condition should NOT be hoisted.""" + + @T.prim_func + def before( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + for i in range(128): + idx = i % 2 + if idx == 0: + B[i] = A[i] + + @T.prim_func + def expected( + A: T.Tensor((128,), T.float32), + B: T.Tensor((128,), T.float32), + ): + # Should remain unchanged since idx depends on loop variable i + for i in range(128): + idx = i % 2 + if idx == 0: + B[i] = A[i] + + _check(before, expected) + + if __name__ == "__main__": tilelang.testing.main()