Skip to content

Commit 657880c

Browse files
authored
[Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule primitive (#16660)
* [Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule primitive

  When inserting a `cache_read` / `cache_write` stage, the `tir.AllocateConst` statement would be duplicated if its body was not a `tir.SeqStmt` node (e.g. a `tir.For`), leading to compilation failures.

  This happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always re-attached to the statement's body after the `cache_read` / `cache_write` stage is inserted into it, but the stage was being added to the whole statement (which already contains the `tir.AllocateConst`) rather than just to its body, causing the duplication.

  This commit also adds a regression test in which the first `cache_read` stage is inserted into a statement whose body is a `tir.For`, while the second stage is added to a body that is a `tir.SeqStmt`.

* Improve PrimFunc readability

* Remove redundant `T.reads()`
1 parent e005f85 commit 657880c

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

src/tir/schedule/primitive/cache_read_write.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,9 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
483483
seq.insert(seq.begin() + pos, stage);
484484
body = SeqStmt(seq);
485485
} else if (pos == 0) {
486-
body = SeqStmt({stage, stmt});
486+
body = SeqStmt({stage, body});
487487
} else if (pos == 1) {
488-
body = SeqStmt({stmt, stage});
488+
body = SeqStmt({body, stage});
489489
} else {
490490
LOG(FATAL) << "Cannot insert at position " << pos
491491
<< ". When inserting adjacent to non-SeqStmt, "

tests/python/tir-schedule/test_tir_schedule_cache_read_write.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,46 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name):
13791379
sch.cache_read(block_b, 0, "test_scope")
13801380

13811381

1382+
def test_cache_read_allocate_const():
# Regression test for `cache_read` on a function containing `T.allocate_const`:
# the constant and its `T.decl_buffer` must each appear exactly once afterwards.
1383+
@T.prim_func
1384+
def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
1385+
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1386+
B_buf = T.decl_buffer((8), dtype="float32", data=B)
1387+
for i in range(8):
1388+
with T.block("C"):
1389+
vi = T.axis.spatial(8, i)
1390+
C[vi] = A[vi] + B_buf[vi]
1391+
1392+
# Expected result: two copy stages ("A_global", "B_buf_global") precede the
# compute loop, and block "C" reads from the cached buffers instead of A / B_buf.
@T.prim_func
1393+
def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
1394+
B_buf_global = T.alloc_buffer((8), dtype="float32")
1395+
A_global = T.alloc_buffer((8), dtype="float32")
1396+
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1397+
B_buf = T.decl_buffer((8), data=B)
1398+
for ax0 in range(8):
1399+
with T.block("A_global"):
1400+
v0 = T.axis.spatial(8, ax0)
1401+
A_global[v0] = A[v0]
1402+
for ax0 in range(8):
1403+
with T.block("B_buf_global"):
1404+
v0 = T.axis.spatial(8, ax0)
1405+
B_buf_global[v0] = B_buf[v0]
1406+
for i in range(8):
1407+
with T.block("C"):
1408+
vi = T.axis.spatial(8, i)
1409+
C[vi] = A_global[vi] + B_buf_global[vi]
1410+
1411+
sch = tir.Schedule(before)
1412+
block_c = sch.get_block("C")
1413+
# Two consecutive cache reads on block "C": read index 1 first, then read index 0.
# In `expected` these show up as the "B_buf_global" and "A_global" stages.
sch.cache_read(block_c, 1, "global")
1414+
sch.cache_read(block_c, 0, "global")
1415+
1416+
after = sch.mod["main"]
1417+
1418+
# The scheduled function must match `expected` structurally.
assert_structural_equal_ignore_global_symbol(expected, after)
1419+
# NOTE(review): presumably replays the recorded schedule trace on `before` — confirm.
verify_trace_roundtrip(sch=sch, mod=before)
1420+
1421+
13821422
def test_inplace_cache_read():
13831423
sch = tvm.tir.Schedule(inplace_func, debug_mask="all")
13841424
block = sch.get_block("copy_in")

0 commit comments

Comments
 (0)