Skip to content

Commit ab62f4a

Browse files
[TIR][Transform] Keep the allocate buffers order after update buffer allocation location
1 parent 6574e16 commit ab62f4a

File tree

2 files changed

+56
-14
lines changed

2 files changed

+56
-14
lines changed

src/tir/transforms/plan_update_buffer_allocation_location.cc

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,35 @@ class CollectUnmanagedAllocations : public StmtExprVisitor {
4848
std::unordered_set<const VarNode*> unmanaged_allocations;
4949
};
5050

51+
/*! \brief Collect the allocate buffer order. */
52+
class BufferAllocateOrderCollector : public StmtExprVisitor {
53+
public:
54+
static Array<Buffer> Collect(const PrimFunc& func) {
55+
BufferAllocateOrderCollector collector;
56+
for (const auto& kv : func->buffer_map) {
57+
collector.buffer_alloc_recorder_.push_back(kv.second);
58+
}
59+
collector(func->body);
60+
return std::move(collector.buffer_alloc_recorder_);
61+
}
62+
63+
private:
64+
void VisitStmt_(const BlockNode* op) final {
65+
for (const Buffer& buffer : op->alloc_buffers) {
66+
buffer_alloc_recorder_.push_back(buffer);
67+
}
68+
StmtExprVisitor::VisitStmt_(op);
69+
}
70+
71+
/*! \brief The buffer allocated order recorder. */
72+
Array<Buffer> buffer_alloc_recorder_;
73+
};
74+
5175
class BufferAllocationLocator : public StmtExprMutator {
5276
public:
5377
explicit BufferAllocationLocator(const PrimFunc& func) {
5478
Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
79+
Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);
5580
std::unordered_set<const VarNode*> arg_buffer_vars;
5681
CollectUnmanagedAllocations collector;
5782
collector(func->body);
@@ -63,16 +88,18 @@ class BufferAllocationLocator : public StmtExprMutator {
6388
buffer_data_to_buffer_.Set(buffer->data, buffer);
6489
}
6590
// create buffers to be allocated at each stmts
66-
for (const auto& kv : buffer_lca) {
67-
const Buffer& buffer = kv.first;
68-
const StmtNode* stmt = kv.second.get();
69-
if (arg_buffer_vars.count(buffer->data.get())) {
70-
continue;
71-
}
72-
if (!unmanaged_allocations_.count(buffer->data.get())) {
73-
alloc_buffers_[stmt].push_back(buffer);
91+
for (const auto& buffer : buffer_alloc_recorder) {
92+
auto it = buffer_lca.find(buffer);
93+
if (it != buffer_lca.end()) {
94+
const StmtNode* stmt = (*it).second.get();
95+
if (arg_buffer_vars.count(buffer->data.get())) {
96+
continue;
97+
}
98+
if (!unmanaged_allocations_.count(buffer->data.get())) {
99+
alloc_buffers_[stmt].push_back(buffer);
100+
}
101+
buffer_data_to_buffer_.Set(buffer->data, buffer);
74102
}
75-
buffer_data_to_buffer_.Set(buffer->data, buffer);
76103
}
77104
}
78105

tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,13 @@ def test_lower_te():
245245

246246
def test_loop_carried_dependency():
247247
"""The buffer allocation should be above opaque iter var's loop scopes
248-
such that buffer accesses with loop carried dependencies are covered."""
248+
such that buffer accesses with loop carried dependencies are covered,
249+
and the allocate buffer should keep the order."""
249250

250251
@T.prim_func
251252
def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
252253
C = T.alloc_buffer([8, 8, 8], dtype="int32")
254+
D = T.alloc_buffer([8, 8, 8], dtype="int32")
253255
for i in T.serial(8):
254256
for j in T.serial(8):
255257
for k in T.serial(8):
@@ -258,10 +260,16 @@ def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
258260
C[vi, vj, vk] = A[vi, vj, vk] + 1
259261
for k in T.serial(8):
260262
with T.block("b1"):
263+
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
264+
D[vi, vj, vk] = A[vi, vj, vk] + 2
265+
for k in T.serial(8):
266+
with T.block("b2"):
261267
vi, vk = T.axis.remap("SS", [i, k])
262268
vj = T.axis.opaque(8, j)
263-
B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
264-
0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
269+
B[vi, vj, vk] = (
270+
C[vi, vj, vk]
271+
+ T.if_then_else(0 < vj, C[vi, vj - 1, vk], 0, dtype="int32")
272+
+ D[vi, vj, vk]
265273
)
266274

267275
@T.prim_func
@@ -271,17 +279,24 @@ def after(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]) -> N
271279
T.reads(A[i, 0:8, 0:8])
272280
T.writes(B[i, 0:8, 0:8])
273281
C = T.alloc_buffer([8, 8, 8], dtype="int32")
282+
D = T.alloc_buffer([8, 8, 8], dtype="int32")
274283
for j in T.serial(8):
275284
for k in T.serial(8):
276285
with T.block("b0"):
277286
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
278287
C[vi, vj, vk] = A[vi, vj, vk] + 1
279288
for k in T.serial(8):
280289
with T.block("b1"):
290+
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
291+
D[vi, vj, vk] = A[vi, vj, vk] + 2
292+
for k in T.serial(8):
293+
with T.block("b2"):
281294
vi, vk = T.axis.remap("SS", [i, k])
282295
vj = T.axis.opaque(8, j)
283-
B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
284-
0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
296+
B[vi, vj, vk] = (
297+
C[vi, vj, vk]
298+
+ T.if_then_else(0 < vj, C[vi, vj - 1, vk], 0, dtype="int32")
299+
+ D[vi, vj, vk]
285300
)
286301

287302
_check(before, after)

0 commit comments

Comments
 (0)