Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@ class ProducerToBufferTransformer : public StmtExprMutator {
const std::unordered_map<te::Tensor, Buffer>& tensor2buffers_;
};

/*! \brief The helper mutator to rewrite buffer and buffer var accessed by block body */
class BufferSubstituter : public StmtExprMutator {
public:
explicit BufferSubstituter(const std::unordered_map<const VarNode*, PrimExpr>& var_map,
const std::unordered_map<const BufferNode*, Buffer>& buffer_map)
: var_map_(var_map), buffer_map_(buffer_map) {}

PrimExpr VisitExpr_(const VarNode* op) final {
auto it = var_map_.find(op);
if (it != var_map_.end()) {
return it->second;
}
return StmtExprMutator::VisitExpr_(op);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_map_.find(load->buffer.get());
if (it != buffer_map_.end()) {
return BufferLoad(it->second, load->indices, load->span);
}
return load;
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_map_.find(store->buffer.get());
if (it != buffer_map_.end()) {
return BufferStore(it->second, store->value, store->indices, store->span);
}
return store;
}

private:
const std::unordered_map<const VarNode*, PrimExpr>& var_map_;
const std::unordered_map<const BufferNode*, Buffer>& buffer_map_;
};

/*! \brief Helper data structure to store information. */
struct CreateFuncInfo {
/*! \brief The Tensor arg_list. */
Expand Down Expand Up @@ -364,13 +402,15 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* info) {
// Step 1. Check all inputs are visited before and update var_map.
std::unordered_map<const VarNode*, PrimExpr> var_map;
std::unordered_map<const BufferNode*, Buffer> input_buffer_map;
ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size());
for (size_t i = 0; i < extern_op->inputs.size(); ++i) {
const Buffer& placeholder = extern_op->input_placeholders[i];
const te::Tensor& input_tensor = extern_op->inputs[i];
auto it = info->tensor2buffers.find(input_tensor);
ICHECK(it != info->tensor2buffers.end());
var_map[placeholder->data.get()] = it->second->data;
input_buffer_map[placeholder.get()] = it->second;
}

// Step 2. Update info with its output tensor and placeholder buffer.
Expand All @@ -394,7 +434,8 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
writes.push_back(BufferRegion::FullRegion(buffer));
}

Stmt body = Substitute(extern_op->body, var_map);
BufferSubstituter substituter(var_map, input_buffer_map);
Stmt body = substituter(extern_op->body);

// Step 4. Generate opaque block as body.
return BlockRealize(/*iter_values=*/{},
Expand Down
6 changes: 4 additions & 2 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class PrimFuncSpecializer : public StmtExprMutator {
Array<BufferRegion> writes = op->writes.Map(
std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));

if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) {
if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
writes.same_as(op->writes)) {
return GetRef<Block>(op);
} else {
ObjectPtr<BlockNode> n = CopyOnWrite(op);
Expand Down Expand Up @@ -238,12 +239,13 @@ class PrimFuncSpecializer : public StmtExprMutator {

BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
auto it = buffer_map_.find(buffer_region->buffer);
const Buffer& buffer = it != buffer_map_.end() ? it->second : buffer_region->buffer;
Array<Range> region = buffer_region->region.Map(
std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
return buffer_region;
} else {
return BufferRegion(it->second, std::move(region));
return BufferRegion(buffer, std::move(region));
}
}

Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,5 +689,35 @@ def test_argmax():
tvm.ir.assert_structural_equal(prim_func, argmax_expected)


def test_extern_with_explicit_buffer_access():
def te_extern():
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
P = te.placeholder((1,), name="P")
C = te.extern(
(128, 128),
[A, B, P],
lambda ins, outs: tvm.tir.call_extern(
"", "myfunc", ins[0].data, ins[1].data, outs[0].data, ins[2][0]
),
name="C",
)
return [A, B, P, C]

@T.prim_func
def tir_extern(var_A: T.handle, var_B: T.handle, var_P: T.handle, var_C: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(var_A, [128, 128], dtype="float32", offset_factor=1)
B = T.match_buffer(var_B, [128, 128], dtype="float32", offset_factor=1)
P = T.match_buffer(var_P, [1], dtype="float32", offset_factor=1)
C = T.match_buffer(var_C, [128, 128], dtype="float32", offset_factor=1)
with T.block("C"):
T.reads(A[0:128, 0:128], B[0:128, 0:128], P[0])
T.writes(C[0:128, 0:128])
T.call_extern("myfunc", A.data, B.data, C.data, P[0], dtype="")

_check_workload(te_extern, tir_extern)


if __name__ == "__main__":
tvm.testing.main()