Skip to content

Commit a23bff7

Browse files
LunderbergLucien0
authored andcommitted
[TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (apache#10998)
* [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator Prior to this commit, the BufferAllocationLocator mutator used in the PlanAndUpdateBufferAllocationLocation pass would erroneously insert an entry to `BlockNode::alloc_buffers` for buffers allocated using `Allocate` or `AllocateConst` nodes. This error was introduced in apache#9727, which deprecated `Load` and `Store` nodes, replacing them with `BufferLoad` and `BufferStore` nodes. As a result, BufferAllocationLocator identified these as buffers whose allocations should be moved to inner loops, rather than as unmanaged allocations that should be ignored. This commit restores the earlier behavior by only operating on buffer allocations in `BlockNode::alloc_buffers`, and explicitly ignoring any buffers whose allocation is done with `Allocate` or `AllocateConst`. * Only inject opaque block if managed buffers exist. Previously, all buffers found were managed buffers, so this check wasn't needed.
1 parent afee8b3 commit a23bff7

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

src/tir/transforms/plan_update_buffer_allocation_location.cc

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,21 @@ class BufferAllocationLocator : public StmtExprMutator {
6161
for (const Buffer& buf : it->second) {
6262
buffer_data_to_buffer_.Set(buf->data, buf);
6363
}
64-
Stmt stmt = StmtMutator::VisitStmt_(op);
65-
op = stmt.as<ForNode>();
66-
ICHECK(op != nullptr);
64+
auto node = Downcast<For>(StmtMutator::VisitStmt_(op));
65+
66+
Array<Buffer> new_block_alloc_bufs;
6767
for (const Buffer& buf : it->second) {
68-
buffer_data_to_buffer_.erase(buf->data);
68+
if (!unmanaged_allocations_.count(buf->data.get())) {
69+
buffer_data_to_buffer_.erase(buf->data);
70+
new_block_alloc_bufs.push_back(buf);
71+
}
6972
}
70-
Stmt body = InjectOpaqueBlock(op->body, it->second);
71-
ObjectPtr<ForNode> n = CopyOnWrite(op);
72-
n->body = std::move(body);
73-
return Stmt(n);
73+
74+
if (new_block_alloc_bufs.size()) {
75+
node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs);
76+
}
77+
78+
return std::move(node);
7479
}
7580

7681
Stmt VisitStmt_(const BlockNode* op) final {
@@ -114,6 +119,16 @@ class BufferAllocationLocator : public StmtExprMutator {
114119
return Stmt(n);
115120
}
116121

122+
Stmt VisitStmt_(const AllocateNode* op) final {
123+
unmanaged_allocations_.insert(op->buffer_var.get());
124+
return StmtExprMutator::VisitStmt_(op);
125+
}
126+
127+
Stmt VisitStmt_(const AllocateConstNode* op) final {
128+
unmanaged_allocations_.insert(op->buffer_var.get());
129+
return StmtExprMutator::VisitStmt_(op);
130+
}
131+
117132
Stmt VisitStmt_(const BufferRealizeNode* op) final {
118133
ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
119134
throw;
@@ -151,6 +166,8 @@ class BufferAllocationLocator : public StmtExprMutator {
151166
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
152167
/*! \brief The buffer already allocated during recursive visiting. */
153168
Map<Var, Buffer> buffer_data_to_buffer_;
169+
/*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved. */
170+
std::unordered_set<const VarNode*> unmanaged_allocations_;
154171
};
155172

156173
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {

tests/python/unittest/test_tir_transform_extract_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def _visit(stmt):
5959
for n, f in mod.functions.items():
6060
tvm.tir.stmt_functor.post_order_visit(f.body, _visit)
6161

62+
tvm.lower(mod)
63+
6264

6365
if __name__ == "__main__":
6466
test_const_extraction()

0 commit comments

Comments
 (0)