Skip to content

Commit 52847a0

Browse files
committed
[Bugfix][TIR] Fix duplicate AllocateConst in CacheReadWrite schedule primitive
When inserting a `cache_read` / `cache_write` stage, the `tir.AllocateConst` statement would be duplicated if its body was not a `tir.SeqStmt` node (e.g. `tir.For`), leading to compilation failures. This happened because `tir.AllocateConst` and `tir.DeclBuffer` statements are always re-attached to the statement's body after the `cache_read` / `cache_write` stage is inserted in it, but the stage was being appended to the whole statement (which already contains the `tir.AllocateConst`) and not just its body, causing duplications. This commit also adds a test where the first `cache_read` stage is inserted into a statement whose body is a `tir.For`, while the second stage is added to a body that is `tir.SeqStmt` to check for regressions.
1 parent c2c579b commit 52847a0

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/tir/schedule/primitive/cache_read_write.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,9 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
483483
seq.insert(seq.begin() + pos, stage);
484484
body = SeqStmt(seq);
485485
} else if (pos == 0) {
486-
body = SeqStmt({stage, stmt});
486+
body = SeqStmt({stage, body});
487487
} else if (pos == 1) {
488-
body = SeqStmt({stmt, stage});
488+
body = SeqStmt({body, stage});
489489
} else {
490490
LOG(FATAL) << "Cannot insert at position " << pos
491491
<< ". When inserting adjacent to non-SeqStmt, "

tests/python/tir-schedule/test_tir_schedule_cache_read_write.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,54 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name):
13791379
sch.cache_read(block_b, 0, "test_scope")
13801380

13811381

1382+
def test_cache_read_allocate_const():
1383+
@T.prim_func
1384+
def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
1385+
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1386+
B_buf = T.decl_buffer((8), dtype="float32", data=B)
1387+
for i in T.serial(128):
1388+
with T.block("C"):
1389+
vi = T.axis.remap("S", [i])
1390+
T.reads(A[vi], B_buf[vi])
1391+
T.writes(C[vi])
1392+
C[vi] = A[vi] + B_buf[vi]
1393+
1394+
@T.prim_func
1395+
def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")):
1396+
B_buf_global = T.alloc_buffer((8), dtype="float32")
1397+
A_global = T.alloc_buffer((8), dtype="float32")
1398+
B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1399+
B_buf = T.decl_buffer((8), data=B)
1400+
for ax0 in range(8):
1401+
with T.block("A_global"):
1402+
v0 = T.axis.spatial(8, ax0)
1403+
T.reads(A[v0])
1404+
T.writes(A_global[v0])
1405+
A_global[v0] = A[v0]
1406+
for ax0 in range(8):
1407+
with T.block("B_buf_global"):
1408+
v0 = T.axis.spatial(8, ax0)
1409+
T.reads(B_buf[v0])
1410+
T.writes(B_buf_global[v0])
1411+
B_buf_global[v0] = B_buf[v0]
1412+
for i in range(128):
1413+
with T.block("C"):
1414+
vi = T.axis.spatial(128, i)
1415+
T.reads(A_global[vi], B_buf_global[vi])
1416+
T.writes(C[vi])
1417+
C[vi] = A_global[vi] + B_buf_global[vi]
1418+
1419+
sch = tir.Schedule(before)
1420+
block_c = sch.get_block("C")
1421+
sch.cache_read(block_c, 1, "global")
1422+
sch.cache_read(block_c, 0, "global")
1423+
1424+
after = sch.mod["main"]
1425+
1426+
assert_structural_equal_ignore_global_symbol(expected, after)
1427+
verify_trace_roundtrip(sch=sch, mod=before)
1428+
1429+
13821430
def test_inplace_cache_read():
13831431
sch = tvm.tir.Schedule(inplace_func, debug_mask="all")
13841432
block = sch.get_block("copy_in")

0 commit comments

Comments
 (0)