Skip to content

Commit 534ff63

Browse files
author
qsqqsqqsq-intellif
committed
[Bugfix][TIR] Fix cache_read update buffer region
Prior to this commit, cache_read primitive may not update the block reads buffer region properly when there is a nested buffer access. This commit fix this bug and add a cache_read unit test.
1 parent 95ec38b commit 534ff63

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

src/tir/schedule/primitive/cache_read_write.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -958,9 +958,10 @@ class CacheReadRewriter : public StmtExprMutator {
958958
// Otherwise, update read regions and match_buffers
959959
// Only make this change if the block is one of the specified consumers.
960960
if (is_consumer) {
961-
Array<BufferRegion> reads = update_access_regions(block->reads);
962-
Array<MatchBufferRegion> match_buffers = update_match_buffers(block->match_buffers);
963-
if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
961+
// Use the updated block stmt
962+
Array<BufferRegion> reads = update_access_regions(stmt->reads);
963+
Array<MatchBufferRegion> match_buffers = update_match_buffers(stmt->match_buffers);
964+
if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) {
964965
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
965966
n->reads = std::move(reads);
966967
n->match_buffers = std::move(match_buffers);

tests/python/tir-schedule/test_tir_schedule_cache_read_write.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,19 @@ def cache_read_nested_seq_target(
488488
C[vi, vj] = A_global[vi, vj] * T.float32(2)
489489

490490

491+
@T.prim_func
492+
def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
493+
A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
494+
B = T.match_buffer(var_B, T.int64(1), dtype="int32")
495+
C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
496+
for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
497+
with T.block("C"):
498+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
499+
T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
500+
T.writes(C[v_ax0, v_ax1])
501+
C[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
502+
503+
491504
########## Expected function after cache_read ##########
492505

493506

@@ -831,6 +844,26 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
831844
data_io[v0] = data_io_global_1[v0]
832845

833846

847+
@T.prim_func
848+
def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
849+
A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
850+
B = T.match_buffer(var_B, T.int64(1), dtype="int32")
851+
C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
852+
B_global = T.alloc_buffer((T.int64(1),), "int32")
853+
for ax0 in range(T.int64(1)):
854+
with T.block("B_global"):
855+
v0 = T.axis.spatial(T.int64(1), ax0)
856+
T.reads(B[v0])
857+
T.writes(B_global[v0])
858+
B_global[v0] = B[v0]
859+
for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
860+
with T.block("C"):
861+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
862+
T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0])
863+
T.writes(C[v_ax0, v_ax1])
864+
C[v_ax0, v_ax1] = A[B_global[v_ax0], v_ax1]
865+
866+
834867
########## Expected function after cache_write ##########
835868

836869

@@ -1358,6 +1391,14 @@ def test_cache_read_non_int32_shape(use_block_name):
13581391
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
13591392

13601393

1394+
def test_cache_read_nested_buffer_access(use_block_name):
1395+
sch = tir.Schedule(nested_buffer_access, debug_mask="all")
1396+
block_c = "C" if use_block_name else sch.get_block("C")
1397+
sch.cache_read(block_c, 1, "global")
1398+
assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"])
1399+
verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)
1400+
1401+
13611402
def test_cache_read_fail_multi_producer(use_block_name):
13621403
sch = tir.Schedule(func_multi_producer, debug_mask="all")
13631404
block_b = "B" if use_block_name else sch.get_block("B")

0 commit comments

Comments
 (0)