Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __init__(self, functions=None, type_definitions=None, attrs=None, global_inf
global_infos,
)

def clone(self) -> "IRModule":
    """Return a copy of this IRModule.

    Backed by the `ir.Module_Clone` FFI function, which performs a
    copy-on-write duplication of the module node, so mutations of the
    clone's function table (e.g. ``del mod["main"]``) do not affect
    the original module.

    Returns
    -------
    mod : IRModule
        The copied module.
    """
    return _ffi_api.Module_Clone(self)

def functions_items(self):
"""Get items in self.functions.items() in alphabetical order.

Expand Down Expand Up @@ -138,6 +141,12 @@ def __getitem__(self, var):
return _ffi_api.Module_Lookup(self, var)
return _ffi_api.Module_LookupDef(self, var)

def __delitem__(self, var: Union[str, _expr.GlobalVar]):
    """Remove a definition from the module.

    Parameters
    ----------
    var : Union[str, _expr.GlobalVar]
        The name, or GlobalVar, of the definition to remove.
    """
    _ffi_api.Module_Remove(self, var)

def __contains__(self, var: Union[str, _expr.GlobalVar]) -> bool:
    """Check whether the module contains a definition.

    Parameters
    ----------
    var : Union[str, _expr.GlobalVar]
        The name, or GlobalVar, to look up.

    Returns
    -------
    contains : bool
        True if the module contains the definition.
    """
    return _ffi_api.Module_Contains(self, var)

def update(self, other):
"""Insert functions in another Module to current one.

Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,8 @@ def RunCodegen(
The registered pass to remove unused functions.
"""
if entry_functions is None:
entry_functions = ["main"]
entry_functions = []

# enable cutlass byoc registries
# pylint: disable=unused-import,import-outside-toplevel
from tvm.contrib import cutlass as _cutlass
Expand Down
34 changes: 34 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ TVM_REGISTER_GLOBAL("ir.IRModule")
return IRModule(funcs, types, {}, {}, dict_attrs, global_infos);
});

// FFI endpoint backing IRModule.clone() on the Python side.
TVM_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule {
  IRModule clone = mod;
  // While `clone` still shares its node with `mod`, CopyOnWrite()
  // duplicates the underlying IRModuleNode, so the returned module owns
  // a separate function table and later mutations do not affect `mod`.
  clone.CopyOnWrite();
  return clone;
});

TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
ICHECK(val->IsInstance<RelayExprNode>());
Expand All @@ -423,6 +429,34 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
return mod;
});

// FFI endpoint backing IRModule.__delitem__ on the Python side.
// Accepts either a function name or a GlobalVar, resolves it to a
// GlobalVar, and removes that definition from the module in place.
TVM_REGISTER_GLOBAL("ir.Module_Remove")
    .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> IRModule {
      // Immediately-invoked lambda normalizes the Variant to a GlobalVar.
      GlobalVar gvar = [&]() {
        if (auto opt = var.as<GlobalVar>()) {
          return opt.value();
        } else if (auto opt = var.as<String>()) {
          // GetGlobalVar reports an error if the name is not defined.
          return mod->GetGlobalVar(opt.value());
        } else {
          // Unreachable: the Variant only admits the two types above.
          LOG(FATAL) << "InternalError: "
                     << "Variant didn't contain any of the allowed types";
        }
      }();
      mod->Remove(gvar);
      return mod;
    });

// FFI endpoint backing IRModule.__contains__ on the Python side.
// A GlobalVar is looked up directly in the function table; a string is
// looked up in the name -> GlobalVar map, so it matches by name.
TVM_REGISTER_GLOBAL("ir.Module_Contains")
    .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> bool {
      if (auto opt = var.as<GlobalVar>()) {
        return mod->functions.count(opt.value());
      } else if (auto opt = var.as<String>()) {
        return mod->global_var_map_.count(opt.value());
      } else {
        // Unreachable: the Variant only admits the two types above.
        LOG(FATAL) << "InternalError: "
                   << "Variant didn't contain any of the allowed types";
      }
    });

TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
Expand Down
40 changes: 34 additions & 6 deletions src/relax/transform/run_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <iostream>

#include "../../support/ordered_set.h"
#include "utils.h"

namespace tvm {
Expand All @@ -39,12 +40,39 @@ class CodeGenRunner : ExprMutator {

explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}

IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String> entry_functions) {
IRModule Run(Optional<Map<String, OptionMap>> target_options,
Array<String> entry_function_names) {
IRModule mod = builder_->GetContextIRModule();
for (const String& entry_func_name : entry_functions) {
auto entry_func = mod->Lookup(entry_func_name);
auto gvar = mod->GetGlobalVar(entry_func_name);
builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(entry_func)));

support::OrderedSet<GlobalVar> entry_functions;
// Any user-provided functions are treated as entry functions.
for (const auto& name : entry_function_names) {
entry_functions.insert(mod->GetGlobalVar(name));
}

// In addition, any externally-exposed function that does not
// belong to a specific codegen may be an entry function. These
// are added in alphabetical order, to ensure consistent order of
// evaluation for debug/test purposes.
{
std::vector<GlobalVar> attr_entry_functions;
for (const auto& [gv, func] : mod->functions) {
if (func->GetLinkageType() == LinkageType::kExternal &&
!func->GetAttr<String>(attr::kCodegen) && func->IsInstance<relax::FunctionNode>()) {
attr_entry_functions.push_back(gv);
}
}
std::sort(attr_entry_functions.begin(), attr_entry_functions.end(),
[](const auto& gvar_a, const auto& gvar_b) {
return gvar_a->name_hint > gvar_b->name_hint;
});
for (const auto& gvar : attr_entry_functions) {
entry_functions.insert(gvar);
}
}

for (const auto& gvar : entry_functions) {
builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(mod->Lookup(gvar))));
}

auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
Expand All @@ -65,7 +93,7 @@ class CodeGenRunner : ExprMutator {
}

// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
return DeadCodeElimination(out_mod, entry_functions);
return DeadCodeElimination(out_mod, entry_function_names);
}

using ExprMutator::VisitExpr_;
Expand Down
43 changes: 36 additions & 7 deletions tests/python/relax/test_transform_codegen_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@
dev = tvm.cuda()


def check_executable(exec, dev, inputs, expected):
def check_executable(exec, dev, inputs, expected, entry_func_name):
vm = relax.VirtualMachine(exec, dev)
out = vm["main"](*inputs)
out = vm[entry_func_name](*inputs)
tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)


def check_roundtrip(exec0, dev, inputs, expected):
def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
exec0.mod.export_library("exec.so")
exec1 = tvm.runtime.load_module("exec.so")
os.remove("exec.so")
assert exec0.stats() == exec1["stats"]()
assert exec0.as_text() == exec1["as_text"]()

check_executable(exec0, dev, inputs, expected)
check_executable(exec1, dev, inputs, expected)
check_executable(exec0, dev, inputs, expected, entry_func_name)
check_executable(exec1, dev, inputs, expected, entry_func_name)


def gen_ground_truth(mod, target, dev, inputs):
Expand Down Expand Up @@ -113,10 +113,17 @@ def setup_test():
return mod, inputs, expected


entry_func_name = tvm.testing.parameter("main", "func")


@tvm.testing.requires_gpu
def test_tensorrt_only():
def test_tensorrt_only(entry_func_name):
mod, inputs, expected = setup_test()

if entry_func_name != "main":
mod[entry_func_name] = mod
del mod["main"]

# Define patterns that we want to offload to byoc
# This test will offload entire model
# Thus, define patterns for both `multiply` and `add` ops
Expand All @@ -135,7 +142,7 @@ def test_tensorrt_only():

ex0 = relax.build(new_mod, target, params={})
# Sanity check for the correctness and roundtrip
check_roundtrip(ex0, dev, inputs, expected)
check_roundtrip(ex0, dev, inputs, expected, entry_func_name)


@tvm.testing.requires_gpu
Expand Down Expand Up @@ -248,6 +255,28 @@ def test_multiple_calls_same_extern():
tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])


def test_default_entry_func():
    """The entry function is not necessarily named "main"

    Like `test_multiple_calls_same_extern`, but the main function is
    named "func".
    """

    def rename_main(mod):
        # Operate on a clone so the caller's module is left untouched.
        renamed = mod.clone()
        renamed["func"] = renamed["main"].with_attr("global_symbol", "func")
        del renamed["main"]
        return renamed

    # Reference run with the conventional "main" entry point.
    before_with_main = Conv2dx2
    after_with_main = relax.transform.RunCodegen()(before_with_main)

    # Re-run the pass after renaming the entry point to "func"; the
    # result should match the renamed reference output.
    before_with_func = rename_main(before_with_main)
    expected_with_func = rename_main(after_with_main)
    after_with_func = relax.transform.RunCodegen()(before_with_func)

    tvm.ir.assert_structural_equal(expected_with_func["func"], after_with_func["func"])



def test_dynamic_shape():
import tvm.relax.backend.contrib.cublas

Expand Down