Skip to content

Commit 877cc24

Browse files
committed
Updated BufferAllocationLocator to ignore aliases of arg buffers
1 parent 0773279 commit 877cc24

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/tir/transforms/plan_update_buffer_allocation_location.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,18 @@ class BufferAllocationLocator : public StmtExprMutator {
3535
public:
3636
explicit BufferAllocationLocator(const PrimFunc& func) {
3737
Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
38-
std::unordered_set<const BufferNode*> arg_buffers;
38+
39+
std::unordered_set<const VarNode*> arg_buffer_vars;
3940
for (const auto& kv : func->buffer_map) {
4041
const Buffer& buffer = kv.second;
41-
arg_buffers.emplace(buffer.get());
42+
arg_buffer_vars.emplace(buffer->data.get());
4243
buffer_data_to_buffer_.Set(buffer->data, buffer);
4344
}
4445
// create buffers to be allocated at each stmts
4546
for (const auto& kv : buffer_lca) {
4647
const Buffer& buffer = kv.first;
4748
const StmtNode* stmt = kv.second.get();
48-
if (arg_buffers.count(buffer.get())) {
49+
if (arg_buffer_vars.count(buffer->data.get())) {
4950
continue;
5051
}
5152
alloc_buffers_[stmt].push_back(buffer);

0 commit comments

Comments
 (0)