From cae0b9a61637c831e9b61da90e063a1c623eb205 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 25 Sep 2021 12:06:27 +0800 Subject: [PATCH 1/2] check dependency for fuse --- include/tvm/tir/schedule/schedule.h | 1 + python/tvm/tir/schedule/schedule.py | 2 +- src/tir/schedule/primitive.h | 1 + .../schedule/primitive/loop_transformation.cc | 39 +++++++++++++++---- .../unittest/test_tir_schedule_split_fuse.py | 18 +++++++++ 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 66dd5375eaf9..9f48d9ab9b1f 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -216,6 +216,7 @@ class ScheduleNode : public runtime::Object { * 1) The loops can't have annotations or thread bindings. * 2) The (i+1)-th loop must be the only child of the i-th loop. * 3) All loops must start with 0. + * 4) The domain of a loop to be fused cannot depend on another loop to be fused. * \param loop_rvs The loops to be fused * \return The new loop after fusion */ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7545c09b020d..979364e8941e 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -373,7 +373,7 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV: 1) The loops can't have annotations or thread bindings. 2) The (i+1)-th loop must be the only child of the i-th loop. 3) All loops must start with 0. - + 4) 4) The domain of a loop to be fused cannot depend on another loop to be fused. Parameters ---------- *loops : List[LoopRV] diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 05eefaca8a11..8ad6bdf7d37f 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -72,6 +72,7 @@ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, * 1) The loops can't have annotations or thread bindings. 
* 2) The inner loop must be the only child of the outer loop. * 3) All loops must start with 0. + * 4) The domain of a loop to be fused cannot depend on another loop to be fused. * \param self The state of the schedule * \param loop_srefs An array of srefs to the loops to be fused * \return The sref to the fused loop diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 95c92aa0a322..7b9ac488b8b9 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -358,17 +358,26 @@ class LoopsNotAChainError : public ScheduleError { class DependentLoopError : public ScheduleError { public: - explicit DependentLoopError(IRModule mod, For loop, String inner_var) - : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {} + enum class PrimitiveKind { kFuse, kReorder }; + explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind) + : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {} String FastErrorString() const final { - return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " - "in the new order"; + if (kind_ == PrimitiveKind::kReorder) { + return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " + "in the new order"; + } else { + return "ScheduleError: A loop's `extent` is dependent on another loop"; + } } String DetailRenderTemplate() const final { - return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + - " in the new order"; + if (kind_ == PrimitiveKind::kReorder) { + return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + + " in the new order"; + } else { + return "A loop {0}'s `extent` is dependent on another loop " + inner_var_; + } } IRModule mod() const final { return mod_; } @@ -377,6 +386,7 @@ class DependentLoopError : public 
ScheduleError { IRModule mod_; For loop_; String inner_var_; + PrimitiveKind kind_; }; Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, @@ -450,6 +460,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { StmtSRef outer_loop_sref{nullptr}; const ForNode* outer_loop = nullptr; arith::Analyzer analyzer; + std::unordered_set<const VarNode*> outer_loop_vars; // Step 1. check correctness for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); @@ -469,6 +480,19 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) { if (!analyzer.CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop)); } + const VarNode* used_var = nullptr; + auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) { + if (outer_loop_vars.count(var)) { + used_var = var; + return true; + } + return false; + }; + if (UsesVar(loop->extent, f_contain)) { + throw DependentLoopError(self->mod, GetRef<For>(loop), used_var->name_hint, + DependentLoopError::PrimitiveKind::kFuse); + } + outer_loop_vars.insert(loop->loop_var.get()); loops.push_back(loop); } // Step 2. 
Create fused loop var and replace the original loop vars @@ -651,7 +675,8 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vectormin, f_contain) || UsesVar(copy->extent, f_contain)) { - throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint); + throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint, + DependentLoopError::PrimitiveKind::kReorder); } inner_vars.insert(copy->loop_var.get()); new_loop = For(std::move(n)); diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 2284f9d996b1..d11e7f877ccc 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -34,6 +34,16 @@ def elementwise(a: ty.handle, b: ty.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 +@tvm.script.tir +def elementwise_dependent_loops(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i in tir.serial(0, 128): + for j, k in tir.grid(i, 128): + with tir.block([128, i, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + @tvm.script.tir def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: A = tir.match_buffer(a, (128, 128, n)) @@ -462,5 +472,13 @@ def test_split_symbolic(): verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) +def test_fuse_fail_with_dependent_loops(): + sch = tir.Schedule(elementwise_dependent_loops, debug_mask="all") + block_b = sch.get_block("B") + i, j, _ = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(i, j) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From bd2c3590a443499f502356e913bf6758639ea794 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 25 Sep 2021 12:10:09 +0800 Subject: [PATCH 2/2] blank line --- python/tvm/tir/schedule/schedule.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 979364e8941e..d26ffc0b1efa 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -373,7 +373,8 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV: 1) The loops can't have annotations or thread bindings. 2) The (i+1)-th loop must be the only child of the i-th loop. 3) All loops must start with 0. - 4) 4) The domain of a loop to be fused cannot depend on another loop to be fused. + 4) The domain of a loop to be fused cannot depend on another loop to be fused. + Parameters ---------- *loops : List[LoopRV]