Skip to content

Commit 9dad95d

Browse files
author
Siyuan Feng
authored
[TIR] Fix block access region detection for nested let bindings (#18069)
Recursively substitute let bindings in buffer indices until no more substitutions are possible. Add test case to verify handling of nested let bindings.
1 parent 6c540e0 commit 9dad95d

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/tir/analysis/block_access_region_detector.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,12 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef
153153

154154
void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
155155
std::vector<arith::IntSet> relaxed_region;
156-
for (const PrimExpr& index : op->indices) {
156+
for (PrimExpr index : op->indices) {
157157
PrimExpr remapped_index = Substitute(index, let_bindings_);
158+
while (!remapped_index.same_as(index)) {
159+
index = remapped_index;
160+
remapped_index = Substitute(index, let_bindings_);
161+
}
158162
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
159163
}
160164
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
@@ -236,8 +240,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
236240

237241
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
238242
std::vector<arith::IntSet> relaxed_region;
239-
for (const PrimExpr& index : op->indices) {
243+
for (PrimExpr index : op->indices) {
240244
PrimExpr remapped_index = Substitute(index, let_bindings_);
245+
while (!remapped_index.same_as(index)) {
246+
index = remapped_index;
247+
remapped_index = Substitute(index, let_bindings_);
248+
}
241249
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_));
242250
}
243251
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);

tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,5 +385,31 @@ def func(
385385
tvm.ir.assert_structural_equal(block.writes, ret[1])
386386

387387

388+
def test_buffer_access_with_nested_let_binding():
389+
@T.prim_func
390+
def func(
391+
A: T.Buffer((16, 16), "float32"),
392+
B: T.Buffer((16, 16), "float32"),
393+
C: T.Buffer((16, 16), "float32"),
394+
):
395+
for i, s in T.grid(16, 16):
396+
with T.block("copy"):
397+
vi, vs = T.axis.remap("SS", [i, s])
398+
T.reads(A[vi, vs], B[vi, vs])
399+
T.writes(C[vi, vs])
400+
vi1: T.int32 = vi
401+
vi2: T.int32 = vi1
402+
vs1: T.int32 = vs
403+
vs2: T.int32 = vs1
404+
vs3: T.int32 = vs2
405+
C[vi, vs1] = A[vi1, vs2] + B[vi2, vs3]
406+
407+
block = func.body.block.body.body.body.block
408+
buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
409+
ret = tir.analysis.get_block_access_region(block, buffer_var_map)
410+
tvm.ir.assert_structural_equal(block.reads, ret[0])
411+
tvm.ir.assert_structural_equal(block.writes, ret[1])
412+
413+
388414
if __name__ == "__main__":
389415
tvm.testing.main()

0 commit comments

Comments
 (0)