Skip to content

Commit 1d98634

Browse files
authored
[TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)
* Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer * add comment
1 parent b7015bb commit 1d98634

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

src/tir/transforms/plan_update_buffer_allocation_location.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,35 @@ class BufferAllocateOrderCollector : public StmtExprVisitor {
6161
}
6262

6363
private:
64+
bool find(const Buffer& buf) {
65+
return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) !=
66+
buffer_alloc_recorder_.end();
67+
}
68+
6469
void VisitStmt_(const BlockNode* op) final {
6570
for (const Buffer& buffer : op->alloc_buffers) {
6671
buffer_alloc_recorder_.push_back(buffer);
6772
}
73+
// Also visit match_buffers to collect constant buffers associated with AllocateConst nodes.
74+
// These buffers only appear in read and match_buffer regions.
75+
for (const auto& region : op->match_buffers) {
76+
if (!find(region->source->buffer)) {
77+
buffer_alloc_recorder_.push_back(region->source->buffer);
78+
}
79+
}
80+
6881
StmtExprVisitor::VisitStmt_(op);
6982
}
7083

7184
void VisitExpr_(const BufferLoadNode* op) final {
72-
if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
73-
buffer_alloc_recorder_.end()) {
85+
if (!find(op->buffer)) {
7486
buffer_alloc_recorder_.push_back(op->buffer);
7587
}
7688
StmtExprVisitor::VisitExpr_(op);
7789
}
7890

7991
void VisitStmt_(const BufferStoreNode* op) final {
80-
if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
81-
buffer_alloc_recorder_.end()) {
92+
if (!find(op->buffer)) {
8293
buffer_alloc_recorder_.push_back(op->buffer);
8394
}
8495
StmtExprVisitor::VisitStmt_(op);

0 commit comments

Comments
 (0)