From 852a30c8f7c93dbe0ef084a82783a16c5d26cf07 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 02:00:27 -0700 Subject: [PATCH 1/7] inint --- src/tir/schedule/analysis.h | 9 ++------- src/tir/schedule/analysis/analysis.cc | 19 +++++++++++-------- src/tir/schedule/primitive/for_kind.cc | 4 +--- .../unittest/test_tir_schedule_for_kind.py | 13 +++++++++++++ 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9c6d1e6e96da..d398f22ed467 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -168,16 +168,11 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl /*! * \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees * rooted at its direct children, and this property requires all the blocks of the subtree - * that the specified sref is in to be complete block or reduction block. + * that the specified sref is in to be local complete block or local reduction block. * \param self The schedule state * \param subtree_root The sref of the subtree root to be checked - * \param scope_root_sref The scope root of the block - * \throw ScheduleError If the subtree that the sref is in doesn't satisfy the compact - * dataflow condition, i.e. a block in the subtree is neither complete block nor - * reduction block */ -void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root, - const StmtSRef& scope_root_sref); +void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root); /*! * \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is * not allocated under the current scope diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c7ed67187793..a2555b9ea763 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -207,14 +207,14 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block static const char* kCompleteBlockDefinition = R"(Definition of a complete block: 1) All block vars are data parallel -2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers under the given scope. 3) No overlap between the buffers the block reads and writes)"; static const char* kReductionBlockDefinition = R"(Definition of a reduction block: 1) The block has the `init` statement 2) All the block bindings are quasi-affine expressions 3) All block vars are either data parallel block vars or reduction block vars -4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers under the given scope. 5) The reduction block vars are not used to index the output buffers)"; bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -363,8 +363,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl reduction_block_error_code); } -void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root, - const StmtSRef& scope_root_sref) { +void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) { class NotCompactDataFlowError : public ScheduleError { public: explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) @@ -375,12 +374,14 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt } String FastErrorString() const final { return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " - "because some of its child block on SRef tree is neither a complete block nor a " + "because some of its child block on SRef tree is neither a local complete block nor a " + "local " "reduction block"; } String DetailRenderTemplate() const final { return "The queried subtree root {0} in SRef tree does not have compact dataflow, because " - "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + "its child block {1} on SRef tree is neither a local complete block nor a local " + "reduction block"; } IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } @@ -392,8 +393,10 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); for (const StmtSRef& block_sref : child_block_srefs) { - if (!IsCompleteBlock(self, block_sref, scope_root_sref) && - !IsReductionBlock(self, block_sref, scope_root_sref)) { + // Local complete: complete block under the subtree. + // Local reduction: reduction block under the subtree. + if (!IsCompleteBlock(self, block_sref, block_sref) && + !IsReductionBlock(self, block_sref, block_sref)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), GetRef(block)); diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 333d78346453..ec337224e59d 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -157,9 +157,7 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref * parallelized/vectorized/bound. */ // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. - StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref, - /*require_stage_pipeline=*/true); - CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref); + CheckSubtreeCompactDataflow(self, loop_sref); // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index caecde05b40f..2cf043ee4765 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -468,5 +468,18 @@ def test_vectorize_after_decompose(): verify_trace_roundtrip(s, mod=decomposed_gemm) +def test_compact_data_flow_local_complete_reduction(): + s = tir.Schedule(decomposed_gemm, debug_mask="all") + init_blk = s.get_block("init") + upd_blk = s.get_block("update") + ii_0, jj_0 = s.get_loops(init_blk) + k_1, ii_1, jj_1 = s.get_child_blocks(upd_blk) + s.vectorize(jj_0) + s.bind(jj_1, "threadIdx.x") + print(s.mod["main"].script()) + tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_double_bound) + verify_trace_roundtrip(s, mod=decomposed_gemm) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 92d351f8a914e3455186166d15ab37e7121c2616 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 02:55:52 -0700 Subject: [PATCH 2/7] upd --- src/tir/schedule/analysis/analysis.cc | 11 ++- .../unittest/test_tir_schedule_for_kind.py | 88 +++++++++++++++++-- 2 files changed, 88 insertions(+), 11 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index a2555b9ea763..6b58d67b8ead 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -156,12 +156,10 @@ bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) { const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = scope->buffer_writers; for (const BufferRegion& write_region : block->writes) { - ICHECK(buffer_writers.count(write_region->buffer)) - << "InternalError: buffer \"" << write_region->buffer->name - << "\" does not exist in the current scope, when querying block:\n" - << GetRef(block); - if (buffer_writers.at(write_region->buffer).size() != 1) { - return false; + if (buffer_writers.count(write_region->buffer)) { + if (buffer_writers.at(write_region->buffer).size() != 1) { + return false; + } } } return true; @@ -395,6 +393,7 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt for (const StmtSRef& block_sref : child_block_srefs) { // Local complete: complete block under the subtree. // Local reduction: reduction block under the subtree. + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); if (!IsCompleteBlock(self, block_sref, block_sref) && !IsReductionBlock(self, block_sref, block_sref)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 2cf043ee4765..8bf3dfe59b8e 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -330,6 +330,72 @@ def decomposed_gemm_after_vectorize( C[vi, vj] = local[vi, vj] +@T.prim_func +def decomposed_gemm_parallelize_init( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + local = T.alloc_buffer([16, 16], dtype="float32") + for i, j in T.grid(4, 4): + for ii in T.serial(4): + for jj in T.vectorized(4): + with T.block("init"): + vi = T.axis.spatial(16, i * 4 + ii) + vj = T.axis.spatial(16, j * 4 + jj) + T.reads() + T.writes(local[vi, vj]) + local[vi, vj] = 0 + for k, ii, jj in T.grid(16, 4, 4): + with T.block("update"): + vi = T.axis.spatial(16, i * 4 + ii) + vj = T.axis.spatial(16, j * 4 + jj) + vk = T.axis.reduce(16, k) + T.reads(local[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(local[vi, vj]) + local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk] + for ii, jj in T.grid(4, 4): + with T.block("C"): + vi = T.axis.spatial(16, i * 4 + ii) + vj = T.axis.spatial(16, j * 4 + jj) + T.reads(local[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = local[vi, vj] + + +@T.prim_func +def scatter_compute(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]): + for i in T.grid(8): + with T.block("first_half"): + vi = T.axis.spatial(16, 8 + i) + B[vi] = A[vi - 8] + + for i in T.grid(8): + with T.block("last_half"): + vi = T.axis.spatial(16, i) + B[vi] = A[vi + 8] + + +@T.prim_func +def scatter_compute_parallelize( + A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"] +) -> None: + # body + # with T.block("root") + for i in T.parallel(8): + with T.block("first_half"): + vi = T.axis.spatial(16, 8 + i) + T.reads(A[vi - 8]) + T.writes(B[vi]) + B[vi] = A[vi - 8] + for i in T.parallel(8): + with T.block("last_half"): + vi = T.axis.spatial(16, i) + T.reads(A[vi + 8]) + T.writes(B[vi]) + B[vi] = A[vi + 8] + + # pylint: enable=no-member,invalid-name,unused-variable @@ -468,18 +534,30 @@ def test_vectorize_after_decompose(): verify_trace_roundtrip(s, mod=decomposed_gemm) -def test_compact_data_flow_local_complete_reduction(): +def test_vectorize_init(): s = tir.Schedule(decomposed_gemm, debug_mask="all") init_blk = s.get_block("init") upd_blk = s.get_block("update") - ii_0, jj_0 = s.get_loops(init_blk) - k_1, ii_1, jj_1 = s.get_child_blocks(upd_blk) + _, _, ii_0, jj_0 = s.get_loops(init_blk) + _, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk) s.vectorize(jj_0) - s.bind(jj_1, "threadIdx.x") print(s.mod["main"].script()) - tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_double_bound) + tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_parallelize_init) verify_trace_roundtrip(s, mod=decomposed_gemm) +def test_scatter_parallelize(): + s = tir.Schedule(scatter_compute, debug_mask="all") + first = s.get_block("first_half") + last = s.get_block("last_half") + (i_0,) = s.get_loops(first) + (i_1,) = s.get_loops(last) + s.parallel(i_0) + s.parallel(i_1) + print(s.mod["main"].script()) + tvm.ir.assert_structural_equal(s.mod["main"], scatter_compute_parallelize) + verify_trace_roundtrip(s, mod=scatter_compute) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From a8de362cb93ec321439e1920bda8ddae7c45970a Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 03:29:51 -0700 Subject: [PATCH 3/7] upd --- src/tir/schedule/analysis/analysis.cc | 40 ++++++++++++++++++--------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 6b58d67b8ead..e987a06cbbc2 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -145,16 +145,32 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { /*! * \brief Check the dominant property of a block: - * the block is the only writer of its output, dominating the reader of its output buffers - * \param scope The block-scope of the block to be checked - * \param block_sref The block whose dominant property is to be checked - * \return A boolean indicating if the block is a dominant block + * the block is the only writer of its output, dominating the reader of its output buffers under the + * given root scope. + * \param self The schedule state. + * \param scope_root_sref The StmtSRef corresponding to the root scope. + * \param block_sref The block whose dominant property is to be checked. + * \return A boolean indicating if the block is a dominant block. */ -bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) { +bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, + const StmtSRef& block_sref) { + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + const BlockNode* maybe_root_block = scope_root_sref->StmtAs(); + if (maybe_root_block) { + BlockScope scope = self->GetBlockScope(scope_root_sref); + buffer_writers = scope->buffer_writers; + } else { + // Collect all child blocks of root sub-tree, and merge their buffer writers. + Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); + for (const StmtSRef& child_block_sref : child_block_srefs) { + BlockScope child_scope = self->GetBlockScope(child_block_sref); + for (const auto& it : child_scope->buffer_writers) { + buffer_writers.insert(it); + } + } + } // Check whether the input block is the only writer of its outputs const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffer_writers = - scope->buffer_writers; for (const BufferRegion& write_region : block->writes) { if (buffer_writers.count(write_region->buffer)) { if (buffer_writers.at(write_region->buffer).size() != 1) { @@ -176,7 +192,6 @@ bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) { */ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - BlockScope scope = self->GetBlockScope(scope_root_sref); // Cond 1. All block vars are data parallel const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); for (const IterVar& iter_var : block->iter_vars) { @@ -186,7 +201,7 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block } // Cond 2. Dominant: the block is the only writer of its output, // dominating the reader of its output buffers - if (!IsDominantBlock(scope, block_sref)) { + if (!IsDominantBlock(self, scope_root_sref, block_sref)) { return 2; } // Cond 3. No overlap between the buffers the block reads and writes @@ -258,7 +273,6 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, */ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { - BlockScope scope = self->GetBlockScope(scope_root_sref); const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); // Cond 1. The block has the `init` statement. if (!block->init.defined()) { @@ -275,7 +289,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc } // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. - if (!IsDominantBlock(scope, block_sref)) { + if (!IsDominantBlock(self, scope_root_sref, block_sref)) { return 4; } // Cond 5. The reduction block vars are not used to index the output buffers. @@ -394,8 +408,8 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt // Local complete: complete block under the subtree. // Local reduction: reduction block under the subtree. const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - if (!IsCompleteBlock(self, block_sref, block_sref) && - !IsReductionBlock(self, block_sref, block_sref)) { + if (!IsCompleteBlock(self, block_sref, subtree_root) && + !IsReductionBlock(self, block_sref, subtree_root)) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), GetRef(block)); From 1be6edc52f5ae08ff54e42d426c16c768de2f169 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 03:31:42 -0700 Subject: [PATCH 4/7] remove redundant print --- tests/python/unittest/test_tir_schedule_for_kind.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 8bf3dfe59b8e..ac8288901688 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -541,7 +541,6 @@ def test_vectorize_init(): _, _, ii_0, jj_0 = s.get_loops(init_blk) _, _, k_1, ii_1, jj_1 = s.get_loops(upd_blk) s.vectorize(jj_0) - print(s.mod["main"].script()) tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_parallelize_init) verify_trace_roundtrip(s, mod=decomposed_gemm) @@ -554,7 +553,6 @@ def test_scatter_parallelize(): (i_1,) = s.get_loops(last) s.parallel(i_0) s.parallel(i_1) - print(s.mod["main"].script()) tvm.ir.assert_structural_equal(s.mod["main"], scatter_compute_parallelize) verify_trace_roundtrip(s, mod=scatter_compute) From 64e6246a6733f3e9c884986be3a02c8f9db7cf86 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 18:14:09 -0700 Subject: [PATCH 5/7] upd --- src/tir/schedule/analysis/analysis.cc | 67 +++++++++++++++++++++------ 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e987a06cbbc2..388413d73b5f 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -143,6 +143,19 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { return std::move(visitor.result); } +/*! + * \brief Check whether the given sref_a is higher than or equal to sref_b. + */ +void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) { + const StmtSRefNode* p = sref_b.get(); + for (; p != nullptr; p = p->parent) { + if (p == sref_a.get()) { + return; + } + } + CHECK(false) << "Expect StmtSRef " << sref_a << "to be higher than or equal to " << sref_b; +} + /*! * \brief Check the dominant property of a block: * the block is the only writer of its output, dominating the reader of its output buffers under the @@ -155,6 +168,7 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, const StmtSRef& block_sref) { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + CheckSRefHigherOrEqual(scope_root_sref, block_sref); const BlockNode* maybe_root_block = scope_root_sref->StmtAs(); if (maybe_root_block) { BlockScope scope = self->GetBlockScope(scope_root_sref); @@ -220,14 +234,26 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block static const char* kCompleteBlockDefinition = R"(Definition of a complete block: 1) All block vars are data parallel -2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers under the given scope. +2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers 3) No overlap between the buffers the block reads and writes)"; static const char* kReductionBlockDefinition = R"(Definition of a reduction block: 1) The block has the `init` statement 2) All the block bindings are quasi-affine expressions 3) All block vars are either data parallel block vars or reduction block vars -4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers under the given scope. +4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +5) The reduction block vars are not used to index the output buffers)"; + +static const char* kLocalCompleteBlockDefinition = R"(Definition of a local complete block: +1) All block vars are data parallel +2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree +3) No overlap between the buffers the block reads and writes)"; + +static const char* kLocalReductionBlockDefinition = R"(Definition of a reduction block: +1) The block has the `init` statement +2) All the block bindings are quasi-affine expressions +3) All block vars are either data parallel block vars or reduction block vars +4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree 5) The reduction block vars are not used to index the output buffers)"; bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, @@ -378,22 +404,32 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) { class NotCompactDataFlowError : public ScheduleError { public: - explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block, + int local_complete_block_code, int local_reduction_block_code) : mod_(std::move(mod)), subtree_root_(std::move(subtree_root)), - violate_block_(std::move(violate_block)) { + violate_block_(std::move(violate_block)), + local_complete_block_code_(local_complete_block_code), + local_reduction_block_code_(local_reduction_block_code) { ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); } String FastErrorString() const final { return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " "because some of its child block on SRef tree is neither a local complete block nor a " - "local " - "reduction block"; + "local reduction block."; } String DetailRenderTemplate() const final { - return "The queried subtree root {0} in SRef tree does not have compact dataflow, because " - "its child block {1} on SRef tree is neither a local complete block nor a local " - "reduction block"; + std::ostringstream os; + os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because " + "its child block {1} on SRef tree is neither a local complete block nor a local " + "reduction block.\n"; + os << "It violates condition #" << local_complete_block_code_ + << " as a local complete block.\n"; + os << kLocalCompleteBlockDefinition << "\n"; + os << "It violates condition #" << local_reduction_block_code_ + << " as a local reduction block.\n"; + os << kLocalReductionBlockDefinition << "\n"; + return os.str(); } IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } @@ -401,18 +437,19 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt IRModule mod_; Stmt subtree_root_; Block violate_block_; + int local_complete_block_code_; + int local_reduction_block_code_; }; Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); for (const StmtSRef& block_sref : child_block_srefs) { - // Local complete: complete block under the subtree. - // Local reduction: reduction block under the subtree. - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - if (!IsCompleteBlock(self, block_sref, subtree_root) && - !IsReductionBlock(self, block_sref, subtree_root)) { + int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), + local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); + if (local_complete_block_code != 0 && local_reduction_block_code != 0) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), - GetRef(block)); + GetRef(block), local_complete_block_code, + local_reduction_block_code); } } } From 7109c10aa88828cf9b9c9badfcae601a57c73f5c Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 18:17:33 -0700 Subject: [PATCH 6/7] change the reads/writes region for argmin/val --- tests/python/unittest/test_te_create_primfunc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 48082c44a4ab..a65c5d8a0bd8 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -395,7 +395,7 @@ def tir_argmax_idx_val( for i0, i1 in T.grid(m, n): with T.block("argmax"): i, k = T.axis.remap("SR", [i0, i1]) - T.reads(argmax_v1[i], val[i, k], argmax_v0[i], idx[i, k]) + T.reads(val[i, k], idx[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) with T.init(): argmax_v0[i] = T.int32(-1) @@ -442,7 +442,7 @@ def tir_argmax_val_idx( for i0, i1 in T.grid(m, n): with T.block("argmax"): i, k = T.axis.remap("SR", [i0, i1]) - T.reads(argmax_v0[i], val[i, k], argmax_v1[i], idx[i, k]) + T.reads(val[i, k], idx[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) with T.init(): argmax_v0[i] = T.min_value("float32") From f870d1d3fd1bc24d7604e2d19c30f201b7155313 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 22 Mar 2022 18:48:13 -0700 Subject: [PATCH 7/7] fix wrong push --- tests/python/unittest/test_te_create_primfunc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index a65c5d8a0bd8..48082c44a4ab 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -395,7 +395,7 @@ def tir_argmax_idx_val( for i0, i1 in T.grid(m, n): with T.block("argmax"): i, k = T.axis.remap("SR", [i0, i1]) - T.reads(val[i, k], idx[i, k]) + T.reads(argmax_v1[i], val[i, k], argmax_v0[i], idx[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) with T.init(): argmax_v0[i] = T.int32(-1) @@ -442,7 +442,7 @@ def tir_argmax_val_idx( for i0, i1 in T.grid(m, n): with T.block("argmax"): i, k = T.axis.remap("SR", [i0, i1]) - T.reads(val[i, k], idx[i, k]) + T.reads(argmax_v0[i], val[i, k], argmax_v1[i], idx[i, k]) T.writes(argmax_v0[i], argmax_v1[i]) with T.init(): argmax_v0[i] = T.min_value("float32")