Skip to content

Commit 916542e

Browse files
authored
[TVMScript] Ensure completed root block has no read/write (#15249)
Prior to this PR, it was possible for the root block of a parsed TIR TVMScript to have non-empty read/write regions, which conflicts with the design of root blocks in TIR. This PR updates the script completion pass to ensure that the root block no longer has any read/write regions.
1 parent 7489ce2 commit 916542e

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

src/tir/ir/script/script_complete.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,18 @@ namespace tir {
3636
class ScriptCompleter : public StmtMutator {
3737
public:
3838
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
39-
/*! \brief Whether the stmt contains at least one block. */
40-
bool contains_block = false;
4139

4240
private:
4341
Map<Var, Buffer>* buffer_var_map_;
44-
Stmt VisitStmt_(const BlockRealizeNode* op) override {
45-
contains_block = true;
42+
Stmt VisitStmt_(const BlockRealizeNode* op) final {
4643
for (const PrimExpr& value : op->iter_values) {
4744
CHECK(value.dtype().is_int())
4845
<< "BlockRealize iter_value expected a IntImm, but got " << value.dtype();
4946
}
5047
return StmtMutator::VisitStmt_(op);
5148
}
5249

53-
Stmt VisitStmt_(const BlockNode* op) override {
50+
Stmt VisitStmt_(const BlockNode* op) final {
5451
// Buffers allocated in the block can be accessed by its body.
5552
for (const auto& alloc_buffer : op->alloc_buffers) {
5653
buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
@@ -59,7 +56,12 @@ class ScriptCompleter : public StmtMutator {
5956
const Buffer& target_buffer = match_buffer->buffer;
6057
buffer_var_map_->Set(target_buffer->data, target_buffer);
6158
}
59+
60+
bool is_root_block = this->is_root_block_;
61+
this->is_root_block_ = false;
6262
Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
63+
this->is_root_block_ = is_root_block;
64+
6365
// Remove buffers allocated inside block to detect its access region
6466
for (const auto& alloc_buffer : op->alloc_buffers) {
6567
buffer_var_map_->erase(alloc_buffer->data);
@@ -85,15 +87,19 @@ class ScriptCompleter : public StmtMutator {
8587
<< "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or "
8688
"direct access by buffer data. Please annotation the access region manually";
8789
auto n = CopyOnWrite(block.operator->());
88-
if (mask & 1) n->reads = reads;
89-
if (mask & 2) n->writes = writes;
90+
if (!is_root_block) {
91+
if (mask & 1) n->reads = reads;
92+
if (mask & 2) n->writes = writes;
93+
}
9094
n->annotations = op->annotations;
9195
n->annotations.erase(attr::script_parsing_detect_access);
9296
return Block(n);
9397
} else {
9498
return std::move(block);
9599
}
96100
}
101+
102+
bool is_root_block_ = true;
97103
};
98104

99105
PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {

tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,6 @@ def layer_norm(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "f
235235
@T.prim_func
236236
def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "float32"), C: T.Buffer((4, 4, 32), "float32"), T_layer_norm: T.Buffer((1, 4, 4, 32), "float32")):
237237
with T.block("root"):
238-
T.reads(A[0, 0:4, 0:4, 0:32], B[0:4, 0:4, 0:32], C[0:4, 0:4, 0:32])
239-
T.writes(T_layer_norm[0, 0:4, 0:4, 0:32])
240238
A_red_temp_v0 = T.alloc_buffer((1,))
241239
A_red_temp_v1 = T.alloc_buffer((1,))
242240
for ax0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}):

tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,6 @@ def main(a: T.handle, b: T.handle):
315315
A = T.match_buffer(a, (1024, 1024))
316316
B = T.match_buffer(b, (1024, 1024))
317317
with T.block("root"):
318-
T.reads(A[0:1024, 0:1024])
319-
T.writes(B[0:1024, 0:1024])
320318
T.block_attr({"warp_execution": True})
321319
for bx in T.thread_binding(8, thread="blockIdx.x"):
322320
for by in T.thread_binding(8, thread="blockIdx.y"):
@@ -583,8 +581,6 @@ class TransformedWmmaToGlobal:
583581
@T.prim_func
584582
def main(C: T.Buffer((1024, 1024), "float32")):
585583
with T.block("root"):
586-
T.reads()
587-
T.writes(C[0:1024, 0:1024])
588584
T.block_attr({"warp_execution": True})
589585
for bx in T.thread_binding(8, thread="blockIdx.x"):
590586
for by in T.thread_binding(8, thread="blockIdx.y"):
@@ -785,8 +781,6 @@ def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024, 1024), "float32"))
785781
s1 = T.int32()
786782
# body
787783
with T.block("root"):
788-
T.reads(A[0:1024])
789-
T.writes(C[0:1024, 0:1024])
790784
T.block_attr({"warp_execution": True})
791785
for bx in T.thread_binding(8, thread="blockIdx.x"):
792786
for by in T.thread_binding(8, thread="blockIdx.y"):
@@ -1009,8 +1003,6 @@ class TransformedMmaToGlobal:
10091003
@T.prim_func
10101004
def main(C: T.Buffer((1024, 1024), "float32")):
10111005
with T.block("root"):
1012-
T.reads()
1013-
T.writes(C[0:1024, 0:1024])
10141006
T.block_attr({"warp_execution": T.bool(True)})
10151007
for bx in T.thread_binding(8, thread="blockIdx.x"):
10161008
for by in T.thread_binding(8, thread="blockIdx.y"):

tests/python/unittest/test_tvmscript_complete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def test_complete_matmul_original():
153153
def _check_elementwise(func):
154154
A, B, C = [func.buffer_map[x] for x in func.params]
155155

156+
root_block = func.body.block
157+
assert len(root_block.reads) == 0
158+
assert len(root_block.writes) == 0
159+
156160
block1 = func.body.block.body[0].body.body.block
157161
assert isinstance(block1, tvm.tir.Block)
158162
vi, vj = [x.var for x in block1.iter_vars]

0 commit comments

Comments
 (0)