Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDataRaceCheck, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTG, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTGPredicated, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableLoopUnswitching, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLoopUnswitchingAllowNonTrivialElse, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableOutOfBoundWarning, Bool);

/*!
 * \brief Return the DataType used for cuTensorMap values.
 *
 * The type is built as DataType::UInt(8, 128).
 * NOTE(review): presumably 8-bit unsigned elements with 128 lanes, i.e. an
 * opaque 128-byte descriptor -- confirm against DataType::UInt's signature.
 */
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
Expand Down
4 changes: 4 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kDisableLoopUnswitching =
"tl.disable_loop_unswitching";
// Allow loop unswitching even when the else-version of the loop body is
// non-trivial (has side effects). Default: false (conservative).
static constexpr const char *kLoopUnswitchingAllowNonTrivialElse =
"tl.loop_unswitching_allow_non_trivial_else";

/*!
* \brief Enable lowering non-predicated global load/store to ldg/stg intrinsics
Expand Down
234 changes: 215 additions & 19 deletions src/transform/loop_unswitching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,156 @@ bool UsesLoopVarThroughLetBindings(
return false;
}

/*!
 * \brief Test whether \p expr refers to any variable in \p vars, either
 * directly or indirectly through Let-bound variables.
 *
 * Generalizes UsesLoopVarThroughLetBindings from a single loop variable to a
 * set of variables. It is used to conservatively block unswitching on
 * per-thread predicates (e.g. threadIdx.x), because later passes may insert
 * synchronization calls that would become control-flow dependent after
 * unswitching.
 *
 * \param expr The expression to inspect.
 * \param vars The variables to search for; an empty set always yields false.
 * \param let_bindings Optional map from Let-bound variable to its bound value.
 *        When null, only direct uses in \p expr are considered.
 * \param visited_let_vars Optional set of Let-bound variables that have
 *        already been expanded. When non-null it is updated in place so each
 *        binding is expanded at most once.
 * \return true if a variable in \p vars is reachable from \p expr.
 */
bool UsesVarsThroughLetBindingsImpl(
    const PrimExpr &expr, const std::unordered_set<const VarNode *> &vars,
    const std::unordered_map<const VarNode *, PrimExpr> *let_bindings,
    std::unordered_set<const VarNode *> *visited_let_vars) {
  // Nothing to search for.
  if (vars.empty()) {
    return false;
  }

  // A direct reference inside the expression settles the question.
  const auto is_tracked = [&vars](const VarNode *v) {
    return vars.count(v) != 0;
  };
  if (UsesVar(expr, is_tracked)) {
    return true;
  }

  // Without binding information there is nothing further to expand.
  if (let_bindings == nullptr) {
    return false;
  }

  bool found = false;
  PostOrderVisit(expr, [&](const ObjectRef &node) {
    // The traversal cannot be aborted, so later callbacks become no-ops once
    // a use has been found.
    if (found) {
      return;
    }
    const auto *var = node.as<VarNode>();
    if (var == nullptr) {
      return;
    }
    const auto binding = let_bindings->find(var);
    if (binding == let_bindings->end()) {
      return;
    }
    // Record the variable before recursing; a failed insertion means this
    // binding has already been expanded and is skipped.
    if (visited_let_vars != nullptr &&
        !visited_let_vars->insert(var).second) {
      return;
    }
    found = UsesVarsThroughLetBindingsImpl(binding->second, vars, let_bindings,
                                           visited_let_vars);
  });

  return found;
}

/*!
 * \brief Entry point for the Let-aware variable-use check.
 *
 * Creates a fresh, empty set of already-expanded Let variables and forwards
 * to UsesVarsThroughLetBindingsImpl.
 *
 * \param expr The expression to inspect.
 * \param vars The variables to search for.
 * \param let_bindings Optional map from Let-bound variable to its bound value.
 * \return The result of UsesVarsThroughLetBindingsImpl for these arguments.
 */
bool UsesVarsThroughLetBindings(
    const PrimExpr &expr, const std::unordered_set<const VarNode *> &vars,
    const std::unordered_map<const VarNode *, PrimExpr> *let_bindings) {
  std::unordered_set<const VarNode *> expanded{};
  const bool reachable =
      UsesVarsThroughLetBindingsImpl(expr, vars, let_bindings, &expanded);
  return reachable;
}

/*!
* \brief Check if a statement is side-effect free (i.e. a no-op), allowing only
* pure/read-only expression evaluation.
*
* This is intentionally conservative, and is used as a profitability/safety
* guard: only unswitch when the "else version" of the loop body does not
* perform any meaningful work. This keeps the common pattern
*
* for i: if cond: S(i)
*
* while avoiding code-size blowup and control-flow complexity for
*
* for i: if cond: S1(i) else: S2(i)
*
* or when there are other side-effecting statements outside the hoisted if.
*/
bool IsSideEffectFreeStmt(const Stmt &stmt) {
if (!stmt.defined()) {
return true;
}

if (const auto *op = stmt.as<EvaluateNode>()) {
// Treat pure or read-only evaluation as no-op.
return SideEffect(op->value) <= CallEffectKind::kReadState;
}

if (const auto *op = stmt.as<SeqStmtNode>()) {
for (const Stmt &s : op->seq) {
if (!IsSideEffectFreeStmt(s)) {
return false;
}
}
return true;
}

if (const auto *op = stmt.as<LetStmtNode>()) {
if (SideEffect(op->value) > CallEffectKind::kReadState) {
return false;
}
return IsSideEffectFreeStmt(op->body);
}

if (const auto *op = stmt.as<IfThenElseNode>()) {
if (SideEffect(op->condition) > CallEffectKind::kReadState) {
return false;
}
if (!IsSideEffectFreeStmt(op->then_case)) {
return false;
}
if (op->else_case.defined() &&
!IsSideEffectFreeStmt(op->else_case.value())) {
return false;
}
return true;
}

if (const auto *op = stmt.as<ForNode>()) {
if (SideEffect(op->min) > CallEffectKind::kReadState ||
SideEffect(op->extent) > CallEffectKind::kReadState) {
return false;
}
return IsSideEffectFreeStmt(op->body);
}

// Conservatively treat all other statements as side-effecting.
return false;
}

/*!
* \brief Check if a condition is loop-invariant
*/
bool IsLoopInvariant(const PrimExpr &cond, const Var &loop_var,
const std::unordered_set<const VarNode *> &written_vars,
const std::unordered_map<const VarNode *, PrimExpr>
*let_bindings = nullptr) {
bool IsLoopInvariant(
const PrimExpr &cond, const Var &loop_var,
const std::unordered_set<const VarNode *> &written_vars,
const std::unordered_map<const VarNode *, PrimExpr> *let_bindings = nullptr,
const std::unordered_set<const VarNode *> *disallowed_vars = nullptr) {
// Check 0: disallow conditions that depend on per-thread binding vars (e.g.
// threadIdx.x). These predicates are loop-invariant, but unswitching them can
// split the execution into different code paths across threads. Later passes
// (e.g. thread sync insertion, fence proxy injection) may add synchronization
// calls outside the hoisted if, which would become control-flow dependent and
// lead to incorrect codegen.
if (disallowed_vars && !disallowed_vars->empty()) {
if (UsesVarsThroughLetBindings(cond, *disallowed_vars, let_bindings)) {
return false;
}
}

// Check 1: must not use loop variable (directly or through Let bindings)
if (UsesLoopVarThroughLetBindings(cond, loop_var, let_bindings)) {
return false;
Expand Down Expand Up @@ -329,13 +472,16 @@ class HoistableIfFinder : public StmtVisitor {
const IfThenElseNode *found = nullptr;
const Var &loop_var;
const std::unordered_set<const VarNode *> &written_vars;
const std::unordered_set<const VarNode *> *disallowed_vars;
std::unordered_map<const VarNode *, PrimExpr> let_bindings_;
// Let bindings that need to be hoisted with the condition
std::vector<std::pair<Var, PrimExpr>> hoisted_let_bindings;

HoistableIfFinder(const Var &loop_var,
const std::unordered_set<const VarNode *> &written_vars)
: loop_var(loop_var), written_vars(written_vars) {}
const std::unordered_set<const VarNode *> &written_vars,
const std::unordered_set<const VarNode *> *disallowed_vars)
: loop_var(loop_var), written_vars(written_vars),
disallowed_vars(disallowed_vars) {}

void VisitStmt_(const LetStmtNode *op) final {
// Track ALL Let bindings to detect when a condition uses a variable
Expand All @@ -352,8 +498,8 @@ class HoistableIfFinder : public StmtVisitor {
void VisitStmt_(const IfThenElseNode *op) final {
if (found)
return;
if (IsLoopInvariant(op->condition, loop_var, written_vars,
&let_bindings_)) {
if (IsLoopInvariant(op->condition, loop_var, written_vars, &let_bindings_,
disallowed_vars)) {
found = op;
// Collect Let-bound variables used in the condition
LetVarCollector collector(let_bindings_);
Expand All @@ -374,7 +520,22 @@ class HoistableIfFinder : public StmtVisitor {
*/
class LoopUnswitcher : public StmtExprMutator {
public:
explicit LoopUnswitcher(bool allow_non_trivial_else)
: allow_non_trivial_else_(allow_non_trivial_else) {}

std::unordered_set<const VarNode *> thread_idx_vars_in_scope_;

Stmt VisitStmt_(const ForNode *op) final {
bool pushed_thread_idx = false;
if (op->thread_binding.defined()) {
String thread_tag = op->thread_binding.value()->thread_tag;
if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" ||
thread_tag == "threadIdx.z") {
thread_idx_vars_in_scope_.insert(op->loop_var.get());
pushed_thread_idx = true;
}
}

// Bottom-up: process nested structures first
Stmt body = VisitStmt(op->body);

Expand All @@ -383,15 +544,22 @@ class LoopUnswitcher : public StmtExprMutator {
collector(body);

// Find hoistable if
HoistableIfFinder finder(op->loop_var, collector.written);
HoistableIfFinder finder(op->loop_var, collector.written,
&thread_idx_vars_in_scope_);
finder(body);

Stmt result;
if (!finder.found) {
if (body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
result = ffi::GetRef<Stmt>(op);
} else {
result = For(op->loop_var, op->min, op->extent, op->kind, body,
op->thread_binding, op->annotations);
}
if (pushed_thread_idx) {
thread_idx_vars_in_scope_.erase(op->loop_var.get());
}
return For(op->loop_var, op->min, op->extent, op->kind, body,
op->thread_binding, op->annotations);
return result;
}

// Check if there are any function calls OUTSIDE the hoisted if statement.
Expand All @@ -403,10 +571,15 @@ class LoopUnswitcher : public StmtExprMutator {
call_checker(body);
if (call_checker.has_call) {
if (body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
result = ffi::GetRef<Stmt>(op);
} else {
result = For(op->loop_var, op->min, op->extent, op->kind, body,
op->thread_binding, op->annotations);
}
return For(op->loop_var, op->min, op->extent, op->kind, body,
op->thread_binding, op->annotations);
if (pushed_thread_idx) {
thread_idx_vars_in_scope_.erase(op->loop_var.get());
}
return result;
}

// Unswitch: create two loop versions
Expand All @@ -418,6 +591,19 @@ class LoopUnswitcher : public StmtExprMutator {
Stmt else_body = IfBranchReplacer(hoisted_condition, false,
finder.hoisted_let_bindings)(body);

// Only unswitch when the else-version does not do any meaningful work.
// This keeps the canonical optimization `for: if(cond) {S}` ->
// `if(cond){for:S}` while avoiding duplicating non-trivial loop bodies into
// two versions.
if (!allow_non_trivial_else_ && !IsSideEffectFreeStmt(else_body)) {
result = For(op->loop_var, op->min, op->extent, op->kind, body,
op->thread_binding, op->annotations);
if (pushed_thread_idx) {
thread_idx_vars_in_scope_.erase(op->loop_var.get());
}
return result;
}

// Create new loop_var for else_loop to maintain SSA form
Var else_loop_var(op->loop_var->name_hint, op->loop_var->dtype);
else_body = Substitute(else_body, {{op->loop_var, else_loop_var}});
Expand All @@ -427,7 +613,7 @@ class LoopUnswitcher : public StmtExprMutator {
For else_loop(else_loop_var, op->min, op->extent, op->kind, else_body,
op->thread_binding, op->annotations);

Stmt result = IfThenElse(if_node->condition, then_loop, else_loop);
result = IfThenElse(if_node->condition, then_loop, else_loop);

// Wrap with hoisted Let bindings (in reverse order so first binding is
// outermost)
Expand All @@ -436,14 +622,20 @@ class LoopUnswitcher : public StmtExprMutator {
result = LetStmt(it->first, it->second, result);
}

if (pushed_thread_idx) {
thread_idx_vars_in_scope_.erase(op->loop_var.get());
}
return result;
}

private:
bool allow_non_trivial_else_{false};
};

// --- Public API ---

Stmt ApplyLoopUnswitching(Stmt stmt) {
return LoopUnswitcher()(std::move(stmt));
Stmt ApplyLoopUnswitching(Stmt stmt, bool allow_non_trivial_else) {
return LoopUnswitcher(allow_non_trivial_else)(std::move(stmt));
}

using namespace tir::transform;
Expand All @@ -455,7 +647,11 @@ tvm::transform::Pass LoopUnswitching() {
if (disable_loop_unswitching) {
return f;
}
f.CopyOnWrite()->body = ApplyLoopUnswitching(f->body);
bool allow_non_trivial_else =
ctx->GetConfig<Bool>(kLoopUnswitchingAllowNonTrivialElse, Bool(false))
.value();
f.CopyOnWrite()->body =
ApplyLoopUnswitching(f->body, allow_non_trivial_else);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.LoopUnswitching", {});
Expand Down
Loading
Loading