Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
}
}

// Visit a call expression.
//
// builtin::address_of gets no special treatment: its argument is traversed by
// the default StmtExprVisitor logic like that of any other call. The previous
// special case visited only the indices of the buffer load passed to
// address_of, so the load node itself was never visited, and it dereferenced
// the result of `as<BufferLoadNode>()` without checking it for null.
// NOTE(review): this relies on the visitor's buffer-load handling (outside
// this block) to register the access to the underlying buffer -- confirm.
void VisitExpr_(const CallNode* op) final {
  StmtExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const VarNode* buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_tir_transform_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,59 @@ def verify(v):
assert num_alloc[0] == 1


def test_address_of():
    # Buffers reached only through T.address_of must still be treated as
    # accessed by StorageRewrite.  In the function below the pass may fold
    # B and D into one allocation, but C has to keep its own storage.
    @T.prim_func
    def before(A: T.Buffer(8, "float32"), E: T.Buffer(8, "float32")):
        B_data = T.allocate([8], "float32")
        B = T.Buffer(8, data=B_data, align=32)
        for i in range(8):
            B[i] = (
                T.call_extern("deref", T.address_of(A[i]), dtype="float32")
                + T.call_extern("deref", T.address_of(A[0]), dtype="float32")
                + T.float32(1)
            )
        C_data = T.allocate([8], "float32")
        C = T.Buffer(8, data=C_data, align=32)
        for i in range(8):
            C[i] = (
                T.call_extern("deref", T.address_of(B[i]), dtype="float32")
                + T.call_extern("deref", T.address_of(B[0]), dtype="float32")
                + T.float32(2)
            )
        D_data = T.allocate([8], "float32")
        D = T.Buffer(8, data=D_data, align=32)
        for i in range(8):
            D[i] = (
                T.call_extern("deref", T.address_of(C[i]), dtype="float32")
                + T.call_extern("deref", T.address_of(C[0]), dtype="float32")
                + T.float32(2)
            )
        for i in range(8):
            E[i] = (
                T.call_extern("deref", T.address_of(D[i]), dtype="float32")
                + T.call_extern("deref", T.address_of(D[0]), dtype="float32")
                + T.float32(3)
            )

    def allocated_elements(body):
        # Add up the first extent of every Allocate node found under `body`.
        sizes = []

        def collect(node):
            if isinstance(node, tvm.tir.Allocate):
                sizes.append(node.extents[0].value)

        tvm.tir.stmt_functor.post_order_visit(body, collect)
        return sum(sizes)

    # Before the pass: three separate allocations of 8 elements each.
    mod = tvm.IRModule.from_expr(before)
    mod.show()
    assert allocated_elements(mod["main"].body) == 24

    # After the pass: 16 elements remain allocated in total.
    mod = tvm.tir.transform.StorageRewrite()(mod)
    mod.show()
    assert allocated_elements(mod["main"].body) == 16


def test_storage_share_gpu():
m = te.var("m")
A = [te.placeholder((m), name="A")]
Expand Down