Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 91 additions & 14 deletions src/transform/loop_unswitching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ class CallNodeChecker : public ExprVisitor {
};

/*!
* \brief Check if a statement contains any CallNode, excluding a specific If
* \brief Check if a statement contains any CallNode, excluding matching If
* nodes
*
* Loop unswitching is unsafe when there are function calls OUTSIDE the
* hoisted if statement, because those calls (originally executed by all
Expand All @@ -153,15 +154,19 @@ class CallNodeChecker : public ExprVisitor {
*
* Calls INSIDE the if are safe because they were already conditionally
* executed before unswitching.
*
* Since we replace ALL if statements with matching conditions, we need to
* exclude all such if statements when checking for calls.
*/
class CallCheckerExcludingIf : public StmtExprVisitor {
public:
bool has_call = false;
const IfThenElseNode *excluded_if = nullptr;
PrimExpr excluded_condition;

void VisitStmt_(const IfThenElseNode *op) final {
if (op == excluded_if) {
// Skip the interior of the excluded if statement
// Skip the interior of any if statement with matching condition
if (excluded_condition.defined() &&
StructuralEqual()(op->condition, excluded_condition)) {
return;
}
StmtExprVisitor::VisitStmt_(op);
Expand Down Expand Up @@ -237,18 +242,35 @@ bool IsLoopInvariant(const PrimExpr &cond, const Var &loop_var,
}

/*!
* \brief Replace a specific if node with its then/else branch
* \brief Replace if nodes with matching condition with their then/else branch
*
* When hoisting a condition out of a loop, we need to replace ALL if statements
* with the same condition, not just the first one found. This ensures that
* in the then-branch all matching conditions are replaced with their then-case,
* and in the else-branch all matching conditions are replaced with their
* else-case.
*
* Also removes LetStmts for variables that have been hoisted, since they are
* now redundant (the variable is already bound outside the loop).
*/
class IfBranchReplacer : public StmtExprMutator {
public:
const IfThenElseNode *target;
PrimExpr hoisted_condition;
bool take_then;

IfBranchReplacer(const IfThenElseNode *target, bool take_then)
: target(target), take_then(take_then) {}
std::unordered_set<const VarNode *> hoisted_vars;

IfBranchReplacer(
const PrimExpr &condition, bool take_then,
const std::vector<std::pair<Var, PrimExpr>> &hoisted_let_bindings)
: hoisted_condition(condition), take_then(take_then) {
for (const auto &binding : hoisted_let_bindings) {
hoisted_vars.insert(binding.first.get());
}
}

Stmt VisitStmt_(const IfThenElseNode *op) final {
if (op == target) {
// Replace if the condition is structurally equal to the hoisted condition
if (StructuralEqual()(op->condition, hoisted_condition)) {
if (take_then) {
return VisitStmt(op->then_case);
} else {
Expand All @@ -258,6 +280,43 @@ class IfBranchReplacer : public StmtExprMutator {
}
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const LetStmtNode *op) final {
// Remove LetStmts for hoisted variables (they are now bound outside the
// loop)
if (hoisted_vars.count(op->var.get())) {
return VisitStmt(op->body);
}
return StmtExprMutator::VisitStmt_(op);
}
};

/*!
* \brief Collect Let-bound variables used in an expression
*/
class LetVarCollector : public ExprVisitor {
public:
std::vector<std::pair<Var, PrimExpr>> used_let_bindings;
const std::unordered_map<const VarNode *, PrimExpr> &let_bindings;
std::unordered_set<const VarNode *> visited;

explicit LetVarCollector(
const std::unordered_map<const VarNode *, PrimExpr> &bindings)
: let_bindings(bindings) {}

void VisitExpr_(const VarNode *op) final {
if (visited.count(op))
return;
auto it = let_bindings.find(op);
if (it != let_bindings.end()) {
visited.insert(op);
// First recursively collect Let-bound vars used in this binding's value
VisitExpr(it->second);
// Then add this binding (so dependencies come first)
used_let_bindings.push_back(
std::make_pair(ffi::GetRef<Var>(op), it->second));
}
}
};

/*!
Expand All @@ -271,6 +330,8 @@ class HoistableIfFinder : public StmtVisitor {
const Var &loop_var;
const std::unordered_set<const VarNode *> &written_vars;
std::unordered_map<const VarNode *, PrimExpr> let_bindings_;
// Let bindings that need to be hoisted with the condition
std::vector<std::pair<Var, PrimExpr>> hoisted_let_bindings;

HoistableIfFinder(const Var &loop_var,
const std::unordered_set<const VarNode *> &written_vars)
Expand All @@ -294,6 +355,10 @@ class HoistableIfFinder : public StmtVisitor {
if (IsLoopInvariant(op->condition, loop_var, written_vars,
&let_bindings_)) {
found = op;
// Collect Let-bound variables used in the condition
LetVarCollector collector(let_bindings_);
collector(op->condition);
hoisted_let_bindings = std::move(collector.used_let_bindings);
return;
}
StmtVisitor::VisitStmt_(op);
Expand Down Expand Up @@ -334,7 +399,7 @@ class LoopUnswitcher : public StmtExprMutator {
// would split them into different code paths, breaking synchronization.
// Calls inside the if are already conditionally executed, so they're safe.
CallCheckerExcludingIf call_checker;
call_checker.excluded_if = finder.found;
call_checker.excluded_condition = finder.found->condition;
call_checker(body);
if (call_checker.has_call) {
if (body.same_as(op->body)) {
Expand All @@ -346,9 +411,12 @@ class LoopUnswitcher : public StmtExprMutator {

// Unswitch: create two loop versions
const IfThenElseNode *if_node = finder.found;
PrimExpr hoisted_condition = if_node->condition;

Stmt then_body = IfBranchReplacer(if_node, true)(body);
Stmt else_body = IfBranchReplacer(if_node, false)(body);
Stmt then_body = IfBranchReplacer(hoisted_condition, true,
finder.hoisted_let_bindings)(body);
Stmt else_body = IfBranchReplacer(hoisted_condition, false,
finder.hoisted_let_bindings)(body);

// Create new loop_var for else_loop to maintain SSA form
Var else_loop_var(op->loop_var->name_hint, op->loop_var->dtype);
Expand All @@ -359,7 +427,16 @@ class LoopUnswitcher : public StmtExprMutator {
For else_loop(else_loop_var, op->min, op->extent, op->kind, else_body,
op->thread_binding, op->annotations);

return IfThenElse(if_node->condition, then_loop, else_loop);
Stmt result = IfThenElse(if_node->condition, then_loop, else_loop);

// Wrap with hoisted Let bindings (in reverse order so first binding is
// outermost)
for (auto it = finder.hoisted_let_bindings.rbegin();
it != finder.hoisted_let_bindings.rend(); ++it) {
result = LetStmt(it->first, it->second, result);
}

return result;
}
};

Expand Down
167 changes: 167 additions & 0 deletions testing/python/transform/test_tilelang_transform_loop_unswitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,172 @@ def expected(
_check(before, expected)


def test_hoist_let_bound_variable():
    """If condition uses a Let-bound variable, both should be hoisted together.

    In ``before`` the name ``pos`` is bound inside the loop to ``C[0]`` (which
    does not involve the loop index ``i``) and the branch condition reads only
    ``pos``.  ``expected`` has the binding and the ``if`` outside the loop,
    with one loop per branch.
    """

    # Input program: binding and branch both live inside the loop body.
    @T.prim_func
    def before(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((1,), T.float32),
    ):
        for i in range(128):
            pos = C[0]
            if pos >= T.float32(0):
                B[i] = A[i]

    # Reference program the transformed `before` is compared against.
    @T.prim_func
    def expected(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((1,), T.float32),
    ):
        # Let binding is hoisted before the if, redundant inner LetStmt is removed
        pos = C[0]
        if pos >= T.float32(0):
            for i in range(128):
                B[i] = A[i]
        else:
            # `before` has no else branch, so the else-loop only holds a
            # `T.evaluate(0)` statement.
            for _i in range(128):
                T.evaluate(0)

    # NOTE(review): `_check` is defined outside this view; presumably it runs
    # the loop-unswitching pass on `before` and compares with `expected`.
    _check(before, expected)


def test_hoist_multiple_let_bound_variables():
    """If condition uses multiple Let-bound variables, all should be hoisted.

    The condition ``x + y >= 0`` reads two names bound inside the loop
    (``x = C[0]``, ``y = C[1]``); neither binding involves the loop index.
    ``expected`` has both bindings, in the same order, outside the loop.
    """

    # Input program: two bindings followed by a branch, all inside the loop.
    @T.prim_func
    def before(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((2,), T.float32),
    ):
        for i in range(128):
            x = C[0]
            y = C[1]
            if x + y >= T.float32(0):
                B[i] = A[i]

    # Reference program the transformed `before` is compared against.
    @T.prim_func
    def expected(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((2,), T.float32),
    ):
        # Let bindings are hoisted before the if, redundant inner LetStmts are removed
        x = C[0]
        y = C[1]
        if x + y >= T.float32(0):
            for i in range(128):
                B[i] = A[i]
        else:
            # `before` has no else branch, so the else-loop only holds a
            # `T.evaluate(0)` statement.
            for _i in range(128):
                T.evaluate(0)

    # NOTE(review): `_check` is defined outside this view; presumably it runs
    # the loop-unswitching pass on `before` and compares with `expected`.
    _check(before, expected)


def test_multiple_identical_conditions():
    """Multiple if statements with the same condition should all be replaced.

    ``before`` contains two sibling ``if cond[0] > 0`` statements in one loop
    body.  ``expected`` has a single outer ``if`` whose then-loop contains both
    then-bodies and whose else-loop contains one ``T.evaluate(0)`` per
    replaced ``if``.
    """

    # Input program: the same loop-invariant condition is tested twice.
    @T.prim_func
    def before(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((128,), T.float32),
        cond: T.Tensor((1,), T.int32),
    ):
        for i in range(128):
            if cond[0] > 0:
                B[i] = A[i]
            if cond[0] > 0:
                C[i] = A[i] * T.float32(2.0)

    # Reference program the transformed `before` is compared against.
    @T.prim_func
    def expected(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((128,), T.float32),
        cond: T.Tensor((1,), T.int32),
    ):
        if cond[0] > 0:
            for i in range(128):
                B[i] = A[i]
                C[i] = A[i] * T.float32(2.0)
        else:
            for _i in range(128):
                # Two statements: one for each `if` in `before` that had no
                # else branch.
                T.evaluate(0)
                T.evaluate(0)

    # NOTE(review): `_check` is defined outside this view; presumably it runs
    # the loop-unswitching pass on `before` and compares with `expected`.
    _check(before, expected)


def test_multiple_identical_conditions_with_else():
    """Multiple if-else statements with the same condition.

    ``before`` contains two sibling ``if cond[0] > 0: ... else: ...``
    statements in one loop body.  ``expected`` has a single outer ``if``: the
    then-loop holds both then-bodies and the else-loop holds both else-bodies.
    """

    # Input program: the same loop-invariant condition guards two if/else
    # statements.
    @T.prim_func
    def before(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((128,), T.float32),
        cond: T.Tensor((1,), T.int32),
    ):
        for i in range(128):
            if cond[0] > 0:
                B[i] = A[i]
            else:
                B[i] = T.float32(0)
            if cond[0] > 0:
                C[i] = A[i] * T.float32(2.0)
            else:
                C[i] = T.float32(1)

    # Reference program the transformed `before` is compared against.
    @T.prim_func
    def expected(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
        C: T.Tensor((128,), T.float32),
        cond: T.Tensor((1,), T.int32),
    ):
        if cond[0] > 0:
            # Then-cases of both original `if` statements.
            for i in range(128):
                B[i] = A[i]
                C[i] = A[i] * T.float32(2.0)
        else:
            # Else-cases of both original `if` statements.
            for i in range(128):
                B[i] = T.float32(0)
                C[i] = T.float32(1)

    # NOTE(review): `_check` is defined outside this view; presumably it runs
    # the loop-unswitching pass on `before` and compares with `expected`.
    _check(before, expected)


def test_no_hoist_let_bound_loop_variant():
    """Let-bound variable depends on loop var, condition should NOT be hoisted.

    ``idx`` is bound to ``i % 2``, so the condition ``idx == 0`` changes with
    the loop index.  ``expected`` is written identically to ``before``.
    """

    # Input program: the branch condition reads a binding derived from `i`.
    @T.prim_func
    def before(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
    ):
        for i in range(128):
            idx = i % 2
            if idx == 0:
                B[i] = A[i]

    # Reference program: same statements as `before`.
    @T.prim_func
    def expected(
        A: T.Tensor((128,), T.float32),
        B: T.Tensor((128,), T.float32),
    ):
        # Should remain unchanged since idx depends on loop variable i
        for i in range(128):
            idx = i % 2
            if idx == 0:
                B[i] = A[i]

    # NOTE(review): `_check` is defined outside this view; presumably it runs
    # the loop-unswitching pass on `before` and compares with `expected`.
    _check(before, expected)


# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
Loading