Skip to content

Commit

Permalink
[TensorIR][Bugfix] Disallow fusing loops with dependency (#9112)
Browse files Browse the repository at this point in the history
* check dependency for fuse

* blank line
  • Loading branch information
jinhongyii authored Sep 25, 2021
1 parent 8f5abaa commit b33a1a7
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 7 deletions.
1 change: 1 addition & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class ScheduleNode : public runtime::Object {
* 1) The loops can't have annotations or thread bindings.
* 2) The (i+1)-th loop must be the only child of the i-th loop.
* 3) All loops must start with 0.
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param loop_rvs The loops to be fused
* \return The new loop after fusion
*/
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
3) All loops must start with 0.
4) The domain of a loop to be fused cannot depend on another loop to be fused.
Parameters
----------
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
* 1) The loops can't have annotations or thread bindings.
* 2) The inner loop must be the only child of the outer loop.
* 3) All loops must start with 0.
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param self The state of the schedule
* \param loop_srefs An array of srefs to the loops to be fused
* \return The sref to the fused loop
Expand Down
39 changes: 32 additions & 7 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,26 @@ class LoopsNotAChainError : public ScheduleError {

class DependentLoopError : public ScheduleError {
public:
explicit DependentLoopError(IRModule mod, For loop, String inner_var)
: mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {}
enum class PrimitiveKind { kFuse, kReorder };
explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind)
: mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {}

String FastErrorString() const final {
return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop "
"in the new order";
if (kind_ == PrimitiveKind::kReorder) {
return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop "
"in the new order";
} else {
return "ScheduleError: A loop's `extent` is dependent on another loop";
}
}

String DetailRenderTemplate() const final {
return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ +
" in the new order";
if (kind_ == PrimitiveKind::kReorder) {
return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ +
" in the new order";
} else {
return "A loop {0}'s `extent` is dependent on another loop " + inner_var_;
}
}

IRModule mod() const final { return mod_; }
Expand All @@ -377,6 +386,7 @@ class DependentLoopError : public ScheduleError {
IRModule mod_;
For loop_;
String inner_var_;
PrimitiveKind kind_;
};

Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
Expand Down Expand Up @@ -450,6 +460,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
StmtSRef outer_loop_sref{nullptr};
const ForNode* outer_loop = nullptr;
arith::Analyzer analyzer;
std::unordered_set<const VarNode*> outer_loop_vars;
// Step 1. check correctness
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
Expand All @@ -469,6 +480,19 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
if (!analyzer.CanProve(loop->min == 0)) {
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
}
const VarNode* used_var = nullptr;
auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) {
if (outer_loop_vars.count(var)) {
used_var = var;
return true;
}
return false;
};
if (UsesVar(loop->extent, f_contain)) {
throw DependentLoopError(self->mod, GetRef<For>(loop), used_var->name_hint,
DependentLoopError::PrimitiveKind::kFuse);
}
outer_loop_vars.insert(loop->loop_var.get());
loops.push_back(loop);
}
// Step 2. Create fused loop var and replace the original loop vars
Expand Down Expand Up @@ -651,7 +675,8 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefN
return false;
};
if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) {
throw DependentLoopError(self->mod, GetRef<For>(copy), used_var->name_hint);
throw DependentLoopError(self->mod, GetRef<For>(copy), used_var->name_hint,
DependentLoopError::PrimitiveKind::kReorder);
}
inner_vars.insert(copy->loop_var.get());
new_loop = For(std::move(n));
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def elementwise(a: ty.handle, b: ty.handle) -> None:
B[vi, vj, vk] = A[vi, vj, vk] * 2.0


# TVMScript fixture with a non-rectangular loop nest: loop `j` iterates over
# `[0, i)`, so its extent is the loop variable of the enclosing loop `i`.
# Used by test_fuse_fail_with_dependent_loops to check that fusing `i` and `j`
# raises a ScheduleError.
@tvm.script.tir
def elementwise_dependent_loops(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128, 128))
B = tir.match_buffer(b, (128, 128, 128))
for i in tir.serial(0, 128):
# `tir.grid(i, 128)` gives `j` the extent `i` and `k` the extent 128.
for j, k in tir.grid(i, 128):
# The second block iteration domain is `i` as well, matching loop `j`.
with tir.block([128, i, 128], "B") as [vi, vj, vk]:
B[vi, vj, vk] = A[vi, vj, vk] * 2.0


@tvm.script.tir
def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
A = tir.match_buffer(a, (128, 128, n))
Expand Down Expand Up @@ -462,5 +472,13 @@ def test_split_symbolic():
verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic)


def test_fuse_fail_with_dependent_loops():
    """Fusing loops `i` and `j` must raise, because `j`'s extent is `i`."""
    schedule = tir.Schedule(elementwise_dependent_loops, debug_mask="all")
    # Only the two outermost loops of block "B" take part in the fuse attempt.
    outer, inner, _ = schedule.get_loops(schedule.get_block("B"))
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.fuse(outer, inner)


# When executed as a script, run pytest on this file, forward any extra
# command-line arguments, and exit with pytest's return code.
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit b33a1a7

Please sign in to comment.