Skip to content

Commit 147ed5e

Browse files
authored
[Unity][CodeGen] RunCodegen based on externally-exposed functions (#16422)
* [IR] Add utility methods to IRModule * `IRModule.clone`: Clone the module. While in C++, a module can be copied using `IRModule::CopyOnWrite()`, copying a module in Python required passing all members into the `IRModule` initializer. The `IRModule.clone` method provides an easier way to copy an `IRModule` from Python. * `IRModule.__delitem__`: Remove a function from the module. This exposes the C++ method `IRModuleNode::Remove` for use in the Python API. This uses the Python `del` keyword, similar to a native Python list. Similar to the existing `IRModule.__getitem__`, this can be called with either a `GlobalVar` or a Python string. * `IRModule.__contains__`: Check if a function is in the module. This allows the Python keyword `in` to check if a module contains a specific function. Similar to the existing `IRModule.__getitem__`, this can be called either with a `GlobalVar` (`if gvar in mod`) or with a Python string (`if "function_name" in mod`). * [Unity][CodeGen] RunCodegen based on externally-exposed functions Prior to this commit, `relax.transform.RunCodegen` required a list of entry functions for a module, defaulting to `"main"` if not specified. The list of entry functions is duplicate information that could be inferred from the module, and should not be required from the user. This commit updates `RunCodegen` to treat all externally-exposed functions as entry points, in the same manner as `DeadCodeElimination`. For backwards compatibility, the `entry_functions` argument is still accepted, and is used to augment the list of externally-exposed functions.
1 parent ff9ebad commit 147ed5e

File tree

5 files changed

+115
-14
lines changed

5 files changed

+115
-14
lines changed

python/tvm/ir/module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def __init__(self, functions=None, type_definitions=None, attrs=None, global_inf
8080
global_infos,
8181
)
8282

83+
def clone(self) -> "IRModule":
    """Return a copy of this module.

    The copy is produced by the ``Module_Clone`` FFI function, so the
    caller does not need to pass every member of the module back into
    the ``IRModule`` initializer.

    Returns
    -------
    cloned : IRModule
        The module returned by ``_ffi_api.Module_Clone``.
    """
    cloned = _ffi_api.Module_Clone(self)
    return cloned
85+
8386
def functions_items(self):
8487
"""Get items in self.functions.items() in alphabetical order.
8588
@@ -138,6 +141,12 @@ def __getitem__(self, var):
138141
return _ffi_api.Module_Lookup(self, var)
139142
return _ffi_api.Module_LookupDef(self, var)
140143

144+
def __delitem__(self, var: Union[str, _expr.GlobalVar]):
    """Remove a function from the module, enabling ``del mod[var]``.

    Parameters
    ----------
    var : Union[str, GlobalVar]
        The function to remove, identified either by its ``GlobalVar``
        or by its name as a string.
    """
    # The FFI call returns the module; the return value is not needed here.
    function_ref = var
    _ffi_api.Module_Remove(self, function_ref)
146+
147+
def __contains__(self, var: Union[str, _expr.GlobalVar]) -> bool:
    """Check whether a function is in the module, enabling ``var in mod``.

    Parameters
    ----------
    var : Union[str, GlobalVar]
        The function to look for, identified either by its
        ``GlobalVar`` or by its name as a string.

    Returns
    -------
    is_present : bool
        The result of ``_ffi_api.Module_Contains`` for this module and
        `var`.
    """
    is_present = _ffi_api.Module_Contains(self, var)
    return is_present
149+
141150
def update(self, other):
142151
"""Insert functions in another Module to current one.
143152

python/tvm/relax/transform/transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,8 @@ def RunCodegen(
574574
The registered pass to remove unused functions.
575575
"""
576576
if entry_functions is None:
577-
entry_functions = ["main"]
577+
entry_functions = []
578+
578579
# enable cutlass byoc registries
579580
# pylint: disable=unused-import,import-outside-toplevel
580581
from tvm.contrib import cutlass as _cutlass

src/ir/module.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,12 @@ TVM_REGISTER_GLOBAL("ir.IRModule")
413413
return IRModule(funcs, types, {}, {}, dict_attrs, global_infos);
414414
});
415415

416+
// FFI entry point that hands back a copy of the given IRModule.
TVM_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule {
  // Take a second handle to the module, then request a writable copy
  // through that handle before returning it.
  IRModule copied(mod);
  copied.CopyOnWrite();
  return copied;
});
421+
416422
TVM_REGISTER_GLOBAL("ir.Module_Add")
417423
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
418424
ICHECK(val->IsInstance<RelayExprNode>());
@@ -423,6 +429,34 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
423429
return mod;
424430
});
425431

432+
// FFI entry point that removes a function from a module.  The function may
// be identified either by its GlobalVar or by its name.
TVM_REGISTER_GLOBAL("ir.Module_Remove")
    .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> IRModule {
      // Normalize the argument to a GlobalVar; a name is resolved through
      // the module's own GetGlobalVar lookup.
      auto to_global_var = [&]() -> GlobalVar {
        if (auto as_gvar = var.as<GlobalVar>()) {
          return as_gvar.value();
        }
        if (auto as_name = var.as<String>()) {
          return mod->GetGlobalVar(as_name.value());
        }
        LOG(FATAL) << "InternalError: "
                   << "Variant didn't contain any of the allowed types";
      };

      mod->Remove(to_global_var());
      return mod;
    });
447+
448+
// FFI entry point that reports whether a module contains a function.
// A GlobalVar argument is looked up in `functions`; a name is looked up
// in `global_var_map_`.
TVM_REGISTER_GLOBAL("ir.Module_Contains")
    .set_body_typed([](IRModule mod, Variant<String, GlobalVar> var) -> bool {
      if (auto as_gvar = var.as<GlobalVar>()) {
        return mod->functions.count(as_gvar.value()) > 0;
      }
      if (auto as_name = var.as<String>()) {
        return mod->global_var_map_.count(as_name.value()) > 0;
      }
      LOG(FATAL) << "InternalError: "
                 << "Variant didn't contain any of the allowed types";
    });
459+
426460
TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
427461

428462
TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")

src/relax/transform/run_codegen.cc

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <iostream>
3030

31+
#include "../../support/ordered_set.h"
3132
#include "utils.h"
3233

3334
namespace tvm {
@@ -39,12 +40,39 @@ class CodeGenRunner : ExprMutator {
3940

4041
explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}
4142

42-
IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String> entry_functions) {
43+
IRModule Run(Optional<Map<String, OptionMap>> target_options,
44+
Array<String> entry_function_names) {
4345
IRModule mod = builder_->GetContextIRModule();
44-
for (const String& entry_func_name : entry_functions) {
45-
auto entry_func = mod->Lookup(entry_func_name);
46-
auto gvar = mod->GetGlobalVar(entry_func_name);
47-
builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(entry_func)));
46+
47+
support::OrderedSet<GlobalVar> entry_functions;
48+
// Any user-provided functions are treated as entry functions.
49+
for (const auto& name : entry_function_names) {
50+
entry_functions.insert(mod->GetGlobalVar(name));
51+
}
52+
53+
// In addition, any externally-exposed function that does not
54+
// belong to a specific codegen may be an entry function. These
55+
// are added in descending (reverse-alphabetical) name order, to ensure consistent order of
56+
// evaluation for debug/test purposes.
57+
{
58+
std::vector<GlobalVar> attr_entry_functions;
59+
for (const auto& [gv, func] : mod->functions) {
60+
if (func->GetLinkageType() == LinkageType::kExternal &&
61+
!func->GetAttr<String>(attr::kCodegen) && func->IsInstance<relax::FunctionNode>()) {
62+
attr_entry_functions.push_back(gv);
63+
}
64+
}
65+
std::sort(attr_entry_functions.begin(), attr_entry_functions.end(),
66+
[](const auto& gvar_a, const auto& gvar_b) {
67+
return gvar_a->name_hint > gvar_b->name_hint;
68+
});
69+
for (const auto& gvar : attr_entry_functions) {
70+
entry_functions.insert(gvar);
71+
}
72+
}
73+
74+
for (const auto& gvar : entry_functions) {
75+
builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(mod->Lookup(gvar))));
4876
}
4977

5078
auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
@@ -65,7 +93,7 @@ class CodeGenRunner : ExprMutator {
6593
}
6694

6795
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
68-
return DeadCodeElimination(out_mod, entry_functions);
96+
return DeadCodeElimination(out_mod, entry_function_names);
6997
}
7098

7199
using ExprMutator::VisitExpr_;

tests/python/relax/test_transform_codegen_pass.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,21 @@
4848
dev = tvm.cuda()
4949

5050

51-
def check_executable(exec, dev, inputs, expected):
51+
def check_executable(exec, dev, inputs, expected, entry_func_name):
5252
vm = relax.VirtualMachine(exec, dev)
53-
out = vm["main"](*inputs)
53+
out = vm[entry_func_name](*inputs)
5454
tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
5555

5656

57-
def check_roundtrip(exec0, dev, inputs, expected):
57+
def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
5858
exec0.mod.export_library("exec.so")
5959
exec1 = tvm.runtime.load_module("exec.so")
6060
os.remove("exec.so")
6161
assert exec0.stats() == exec1["stats"]()
6262
assert exec0.as_text() == exec1["as_text"]()
6363

64-
check_executable(exec0, dev, inputs, expected)
65-
check_executable(exec1, dev, inputs, expected)
64+
check_executable(exec0, dev, inputs, expected, entry_func_name)
65+
check_executable(exec1, dev, inputs, expected, entry_func_name)
6666

6767

6868
def gen_ground_truth(mod, target, dev, inputs):
@@ -113,10 +113,17 @@ def setup_test():
113113
return mod, inputs, expected
114114

115115

116+
entry_func_name = tvm.testing.parameter("main", "func")
117+
118+
116119
@tvm.testing.requires_gpu
117-
def test_tensorrt_only():
120+
def test_tensorrt_only(entry_func_name):
118121
mod, inputs, expected = setup_test()
119122

123+
if entry_func_name != "main":
124+
mod[entry_func_name] = mod
125+
del mod["main"]
126+
120127
# Define patterns that we want to offload to byoc
121128
# This test will offload entire model
122129
# Thus, define patterns for both `multiply` and `add` ops
@@ -135,7 +142,7 @@ def test_tensorrt_only():
135142

136143
ex0 = relax.build(new_mod, target, params={})
137144
# Sanity check for the correctness and roundtrip
138-
check_roundtrip(ex0, dev, inputs, expected)
145+
check_roundtrip(ex0, dev, inputs, expected, entry_func_name)
139146

140147

141148
@tvm.testing.requires_gpu
@@ -248,6 +255,28 @@ def test_multiple_calls_same_extern():
248255
tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
249256

250257

258+
def test_default_entry_func():
    """RunCodegen needs no entry function named "main".

    Same scenario as `test_multiple_calls_same_extern`, except that the
    entry function is renamed to "func".  Running `RunCodegen` on the
    renamed module must give the same function as renaming the output
    of `RunCodegen` on the original module.
    """

    def with_main_renamed(module):
        # Work on a clone so the module passed in is left untouched.
        renamed = module.clone()
        renamed["func"] = renamed["main"].with_attr("global_symbol", "func")
        del renamed["main"]
        return renamed

    main_before = Conv2dx2
    main_after = relax.transform.RunCodegen()(main_before)

    func_before = with_main_renamed(main_before)
    func_expected = with_main_renamed(main_after)
    func_after = relax.transform.RunCodegen()(func_before)

    tvm.ir.assert_structural_equal(func_expected["func"], func_after["func"])
278+
279+
251280
def test_dynamic_shape():
252281
import tvm.relax.backend.contrib.cublas
253282

0 commit comments

Comments
 (0)