@@ -61,24 +61,35 @@ class BufferAllocateOrderCollector : public StmtExprVisitor {
6161 }
6262
6363 private:
64+ bool find (const Buffer& buf) {
65+ return std::find (buffer_alloc_recorder_.begin (), buffer_alloc_recorder_.end (), buf) !=
66+ buffer_alloc_recorder_.end ();
67+ }
68+
6469 void VisitStmt_ (const BlockNode* op) final {
6570 for (const Buffer& buffer : op->alloc_buffers ) {
6671 buffer_alloc_recorder_.push_back (buffer);
6772 }
73+ // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes.
74+ // These buffers only appear in read and match_buffer regions.
75+ for (const auto & region : op->match_buffers ) {
76+ if (!find (region->source ->buffer )) {
77+ buffer_alloc_recorder_.push_back (region->source ->buffer );
78+ }
79+ }
80+
6881 StmtExprVisitor::VisitStmt_ (op);
6982 }
7083
7184 void VisitExpr_ (const BufferLoadNode* op) final {
72- if (std::find (buffer_alloc_recorder_.begin (), buffer_alloc_recorder_.end (), op->buffer ) ==
73- buffer_alloc_recorder_.end ()) {
85+ if (!find (op->buffer )) {
7486 buffer_alloc_recorder_.push_back (op->buffer );
7587 }
7688 StmtExprVisitor::VisitExpr_ (op);
7789 }
7890
7991 void VisitStmt_ (const BufferStoreNode* op) final {
80- if (std::find (buffer_alloc_recorder_.begin (), buffer_alloc_recorder_.end (), op->buffer ) ==
81- buffer_alloc_recorder_.end ()) {
92+ if (!find (op->buffer )) {
8293 buffer_alloc_recorder_.push_back (op->buffer );
8394 }
8495 StmtExprVisitor::VisitStmt_ (op);
0 commit comments