From deccfdb0c45c0440d9a2c86c9efa19058f2049ec Mon Sep 17 00:00:00 2001 From: Karl Koscher Date: Wed, 29 Mar 2023 13:59:45 -0700 Subject: [PATCH] Remove special-casing of T.address_of in the storage rewrite pass The use case for this no longer exists, as loads and stores are no longer tracked separately. This special-casing can also introduce bugs when calling external microkernels. Co-authored-by: Eric Lunderberg --- src/tir/transforms/storage_rewrite.cc | 11 ---- .../test_tir_transform_storage_rewrite.py | 53 +++++++++++++++++++ 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index bb76617d8ac5..240b16aa5b1f 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -154,17 +154,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { } } - void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::address_of())) { - const BufferLoadNode* load = op->args[0].as<BufferLoadNode>(); - for (const auto& index : load->indices) { - this->VisitExpr(index); - } - } else { - StmtExprVisitor::VisitExpr_(op); - } - } - - void VisitExpr_(const VarNode* buf) final { // Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf); diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index bcf498659902..cff76766b366 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -258,6 +258,59 @@ def verify(v): assert num_alloc[0] == 1 +def test_address_of(): + # In this test, the storage rewrite pass is allowed to + # combine buffers B and D, but not C + @T.prim_func + def before(A: T.Buffer(8, "float32"), E: T.Buffer(8, "float32")): + B_data = T.allocate([8], "float32") + B = T.Buffer(8, data=B_data, align=32) + for i in range(8): + B[i] = ( + T.call_extern("deref", T.address_of(A[i]), dtype="float32") + + T.call_extern("deref", T.address_of(A[0]), dtype="float32") + + T.float32(1) + ) + C_data = T.allocate([8], "float32") + C = T.Buffer(8, data=C_data, align=32) + for i in range(8): + C[i] = ( + T.call_extern("deref", T.address_of(B[i]), dtype="float32") + + T.call_extern("deref", T.address_of(B[0]), dtype="float32") + + T.float32(2) + ) + D_data = T.allocate([8], "float32") + D = T.Buffer(8, data=D_data, align=32) + for i in range(8): + D[i] = ( + T.call_extern("deref", T.address_of(C[i]), dtype="float32") + + T.call_extern("deref", T.address_of(C[0]), dtype="float32") + + T.float32(2) + ) + for i in range(8): + E[i] = ( + T.call_extern("deref", T.address_of(D[i]), dtype="float32") + + T.call_extern("deref", T.address_of(D[0]), dtype="float32") + + T.float32(3) + ) + + def verify(n): + if isinstance(n, tvm.tir.Allocate): + total_alloc[0] += n.extents[0].value + + total_alloc = [0] + mod = tvm.IRModule.from_expr(before) + mod.show() + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify) + assert total_alloc[0] == 24 + + total_alloc[0] = 0 + mod = tvm.tir.transform.StorageRewrite()(mod) + mod.show() + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify) + assert total_alloc[0] == 
16 + + def test_storage_share_gpu(): m = te.var("m") A = [te.placeholder((m), name="A")]