// Patch summary (unified diff; the line breaks of the original patch are collapsed in this copy).
//
// src/te/operation/create_primfunc.cc
//   - Adds BufferSubstituter, a StmtExprMutator that returns the var_map_ entry for any Var found
//     there, and rebuilds BufferLoad / BufferStore nodes with the buffer_map_ entry when the
//     accessed buffer is found there (indices, value and span are carried over unchanged).
//   - GenerateStmtFromExternOp additionally records each input placeholder buffer in
//     input_buffer_map and rewrites extern_op->body with BufferSubstituter instead of
//     Substitute(extern_op->body, var_map).
//
// src/tir/ir/specialize.cc (PrimFuncSpecializer)
//   - The early "nothing changed" return now also requires writes.same_as(op->writes).
//   - MutateBufferRegion selects the buffer up front, falling back to buffer_region->buffer, so
//     it->second is no longer used when the lookup missed (it == buffer_map_.end()) but the
//     region was mutated.
//
// tests/python/unittest/test_te_create_primfunc.py
//   - Adds test_extern_with_explicit_buffer_access: the extern body passes ins[2][0], and the
//     expected PrimFunc block reads P[0].
//
// NOTE(review): template arguments (e.g. on std::unordered_map, Downcast, Array, GetRef) appear to
// have been stripped from this copy -- confirm against the upstream commit before applying.
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 21456af1bdf4..92186a4ffea4 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -60,6 +60,44 @@ class ProducerToBufferTransformer : public StmtExprMutator { const std::unordered_map& tensor2buffers_; }; +/*! \brief The helper mutator to rewrite buffer and buffer var accessed by block body */ +class BufferSubstituter : public StmtExprMutator { + public: + explicit BufferSubstituter(const std::unordered_map& var_map, + const std::unordered_map& buffer_map) + : var_map_(var_map), buffer_map_(buffer_map) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_map_.find(op); + if (it != var_map_.end()) { + return it->second; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_map_.find(load->buffer.get()); + if (it != buffer_map_.end()) { + return BufferLoad(it->second, load->indices, load->span); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_map_.find(store->buffer.get()); + if (it != buffer_map_.end()) { + return BufferStore(it->second, store->value, store->indices, store->span); + } + return store; + } + + private: + const std::unordered_map& var_map_; + const std::unordered_map& buffer_map_; +}; + /*! \brief Helper data structure to store information. */ struct CreateFuncInfo { /*! \brief The Tensor arg_list. */ @@ -364,6 +402,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) { // Step 1. Check all inputs are visited before and update var_map. 
std::unordered_map var_map; + std::unordered_map input_buffer_map; ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); for (size_t i = 0; i < extern_op->inputs.size(); ++i) { const Buffer& placeholder = extern_op->input_placeholders[i]; @@ -371,6 +410,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf auto it = info->tensor2buffers.find(input_tensor); ICHECK(it != info->tensor2buffers.end()); var_map[placeholder->data.get()] = it->second->data; + input_buffer_map[placeholder.get()] = it->second; } // Step 2. Update info with its output tensor and placeholder buffer. @@ -394,7 +434,8 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf writes.push_back(BufferRegion::FullRegion(buffer)); } - Stmt body = Substitute(extern_op->body, var_map); + BufferSubstituter substituter(var_map, input_buffer_map); + Stmt body = substituter(extern_op->body); // Step 4. Generate opaque block as body. return BlockRealize(/*iter_values=*/{}, diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index ea68015bc73b..7ead6e6ae6fb 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -128,7 +128,8 @@ class PrimFuncSpecializer : public StmtExprMutator { Array writes = op->writes.Map( std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); - if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) { + if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && + writes.same_as(op->writes)) { return GetRef(op); } else { ObjectPtr n = CopyOnWrite(op); @@ -238,12 +239,13 @@ class PrimFuncSpecializer : public StmtExprMutator { BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { auto it = buffer_map_.find(buffer_region->buffer); + const Buffer& buffer = it != buffer_map_.end() ? 
it->second : buffer_region->buffer; Array region = buffer_region->region.Map( std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { return buffer_region; } else { - return BufferRegion(it->second, std::move(region)); + return BufferRegion(buffer, std::move(region)); } } diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index c13ede08313d..f78dc458d9d3 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -689,5 +689,35 @@ def test_argmax(): tvm.ir.assert_structural_equal(prim_func, argmax_expected) +def test_extern_with_explicit_buffer_access(): + def te_extern(): + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + P = te.placeholder((1,), name="P") + C = te.extern( + (128, 128), + [A, B, P], + lambda ins, outs: tvm.tir.call_extern( + "", "myfunc", ins[0].data, ins[1].data, outs[0].data, ins[2][0] + ), + name="C", + ) + return [A, B, P, C] + + @T.prim_func + def tir_extern(var_A: T.handle, var_B: T.handle, var_P: T.handle, var_C: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [128, 128], dtype="float32", offset_factor=1) + B = T.match_buffer(var_B, [128, 128], dtype="float32", offset_factor=1) + P = T.match_buffer(var_P, [1], dtype="float32", offset_factor=1) + C = T.match_buffer(var_C, [128, 128], dtype="float32", offset_factor=1) + with T.block("C"): + T.reads(A[0:128, 0:128], B[0:128, 0:128], P[0]) + T.writes(C[0:128, 0:128]) + T.call_extern("myfunc", A.data, B.data, C.data, P[0], dtype="") + + _check_workload(te_extern, tir_extern) + + if __name__ == "__main__": tvm.testing.main()