Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 90 additions & 19 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,79 @@
namespace tvm {
namespace tir {

PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const std::unordered_set<const GlobalVarNode*>& external_methods,
Stmt stmt) {
SubroutineCallRewriter rewriter(external_methods);
stmt = rewriter.VisitStmt(std::move(stmt));
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
}
}

private:
explicit SubroutineCallRewriter(const std::unordered_set<const GlobalVarNode*>& external_methods)
: external_methods_(external_methods) {}

PrimExpr VisitExpr_(const CallNode* op) override {
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));

if (auto gvar = node->op.as<GlobalVarNode>()) {
if (external_methods_.count(gvar)) {
Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr {
if (auto* as_call = arg.as<CallNode>()) {
if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
PrimExpr data_ptr = as_call->args[0];
made_change_ = true;
return data_ptr;
}
}
return arg;
});
if (!args.same_as(node->args)) {
node.CopyOnWrite()->args = args;
}
}
}

return std::move(node);
}
const std::unordered_set<const GlobalVarNode*>& external_methods_;
bool made_change_{false};
};

PrimFunc MakeUnpackedAPI(PrimFunc func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return func;
}
}

// Internal function calls do not need API updates
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute";
if (!global_symbol.defined()) {
return func;
}

auto target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute";
Target target = [&]() {
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt) << "MakeUnpackedAPI required the function to be annotated with tvm::attr::kTarget ("
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
return opt.value();
}();
int target_device_type = target->GetTargetDeviceType();

auto* func_ptr = func.CopyOnWrite();

// Setup device context
int target_device_type = target.value()->GetTargetDeviceType();
Integer device_type(target_device_type);
Integer device_id(0);
PrimExpr node = StringImm("default");
ObjectRef node = String("default");
const Stmt nop = Evaluate(0);
std::vector<Stmt> device_init;

Expand Down Expand Up @@ -82,31 +141,43 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return std::move(func);
return func;
}

namespace transform {

Pass MakeUnpackedAPI() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc>> updates;
auto pass_func = [](IRModule mod, PassContext ctx) {
std::unordered_set<const GlobalVarNode*> external_methods;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto* prim_func = base_func.as<PrimFuncNode>()) {
if (prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
external_methods.insert(gvar.get());
}
}
}

IRModule updates;

for (const auto& kv : mptr->functions) {
if (auto opt = kv.second.as<PrimFunc>()) {
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
auto updated_func = MakeUnpackedAPI(std::move(func));
updates.push_back({kv.first, updated_func});

if (auto body = SubroutineCallRewriter::Apply(external_methods, func->body)) {
func.CopyOnWrite()->body = body.value();
}

func = MakeUnpackedAPI(std::move(func));
if (!func.same_as(base_func)) {
updates->Add(gvar, func);
}
}
}

for (const auto& pair : updates) {
mptr->AddUnchecked(pair.first, pair.second);
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return m;
return mod;
};

return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {});
Expand Down
158 changes: 152 additions & 6 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import pytest

import tvm
from tvm import te
from tvm import te, tir
from tvm.script import tir as T, ir as I
import numpy


Expand All @@ -39,17 +40,20 @@ def mod(mod_without_attrs):
return mod


def test_fails_if_not_global_symbol(mod_without_attrs):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
def test_noop_if_not_global_symbol(mod_without_attrs):
before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm")))(
mod_without_attrs
)
with pytest.raises(tvm.TVMError, match="Expect PrimFunc to have the global_symbol attribute"):
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
after = tvm.tir.transform.MakeUnpackedAPI()(before)
tvm.ir.assert_structural_equal(before, after)


def test_fails_if_no_target(mod_without_attrs):
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod_without_attrs)
with pytest.raises(tvm.TVMError, match="Require the target attribute"):
with pytest.raises(
tvm.TVMError,
match="MakeUnpackedAPI required the function to be annotated with tvm::attr::kTarget",
):
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]


Expand Down Expand Up @@ -134,5 +138,147 @@ def test_body():
assert f.params[2].name == "A"


class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
"""Internal subroutines do not require modification

A subroutine without the "global_symbol" attribute is an internal
subroutine, and is not directly exposed to a user of the generated
`runtime.Module`.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("llvm")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"target": T.target("llvm")})
T.evaluate(A_data)

return mod


class TestSubroutineCallToExternallyVisibleSubroutine(tvm.testing.CompareBeforeAfter):
"""Externally-visible subroutines should be updated

Subroutines that are exposed externally should be updated by
MakeUnpackedAPI.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
mod.subroutine(A.data)

@T.prim_func
def subroutine(A_data: T.handle("float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)

return mod


class TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeAfter):
"""Callsites of externally-visible subroutines may require updates

The MakeUnpackedAPI transform lowers all buffers into a data
pointer to a primitive type. If a subroutine call is currently
passing a DLTensor produced by `T.tvm_make_stack_array` into the
subroutine, the callsite should be updated to instead pass the
data pointer directly.
"""

transform = tvm.tir.transform.MakeUnpackedAPI()

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
mod.subroutine(
T.tvm_stack_make_array(
A.data,
T.tvm_stack_make_shape(1, dtype="handle"),
T.reinterpret(T.uint64(0), dtype="handle"),
T.uint32(1),
T.Cast("float32", 0),
0,
dtype="handle",
)
)

@T.prim_func
def subroutine(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A.data)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
T.evaluate(A_data)

return mod


if __name__ == "__main__":
tvm.testing.main()