Skip to content

Commit

Permalink
[AOT] Name mangling in AOT (apache#8014)
Browse files Browse the repository at this point in the history
* [AOT] Name mangling in AOT

Mini-RFC is here: https://discuss.tvm.apache.org/t/mini-rfc-name-mangling-in-aot

With this change we'll mangle the name of global symbols so that we can bundle
together multiple models in the same application.

The relay.build interface has been left unchanged, which means I am
resuing mod_name as a prefix for all functions. If mod_name is None then
a "_tvm" prefix is used.

I had to add two different compilation functions:
- _CompileEngineLowerWithModuleName to mangle all the operators with the mod_name
- PartitionGraphWithModName to mangle all the operators produced by BYOC

I could have changed signature of both, but that would have meant a very
invasive refactoring.

I refactored the aot test utils and added some tests for multiple
models.

Change-Id: I30e93fa075f660054577ea36cf9268ec0c6eebcb

* retrigger CI

Change-Id: I4f11da7fce1327ad89bb25f25209b57077b2c6a3
  • Loading branch information
Giuseppe Rossini authored and ylc committed Jan 13, 2022
1 parent 108c1dc commit 86fee54
Show file tree
Hide file tree
Showing 33 changed files with 600 additions and 195 deletions.
4 changes: 2 additions & 2 deletions apps/microtvm/zephyr/aot_demo/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#define WORKSPACE_SIZE (270 * 1024)

static uint8_t g_aot_memory[WORKSPACE_SIZE];
extern tvm_model_t network;
extern tvm_model_t tvmgen_default_network;
tvm_workspace_t app_workspace;

// Wakeup sequence used to wake up QEMU on the host.
Expand Down Expand Up @@ -205,7 +205,7 @@ void main(void) {

double elapsed_time = 0;
TVMPlatformTimerStart();
int ret_val = tvm_runtime_run(&network, inputs, outputs);
int ret_val = tvm_runtime_run(&tvmgen_default_network, inputs, outputs);
TVMPlatformTimerStop(&elapsed_time);

if (ret_val != 0) {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief The main AOT executor function */
constexpr const char* tvm_run_func_prefix = "tvm__run_func";
constexpr const char* tvm_run_func_suffix = "run_model";
} // namespace symbol

// implementations of inline functions.
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given Module tree."""


def _populate_codegen_dir(mod, codegen_dir: str):
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
"""Populate the codegen sub-directory as part of a Model Library Format export.
Parameters
Expand All @@ -44,6 +44,9 @@ def _populate_codegen_dir(mod, codegen_dir: str):
Module which should be written to codegen_dir.
codegen_dir : str
Path to the codegen directory on disk.
module_name: Optional[str]
Name used to prefix the generated source files
"""
dso_modules = mod._collect_dso_modules()
dso_module_handles = [m.handle.value for m in dso_modules]
Expand All @@ -55,17 +58,19 @@ def _populate_codegen_dir(mod, codegen_dir: str):

mod_indices = {"lib": 0, "src": 0}
host_codegen_dir = os.path.join(codegen_dir, "host")
lib_name = f"{module_name}_lib" if module_name else "lib"

for dso_mod in dso_modules:
if dso_mod.type_key == "c":
index = mod_indices["src"]
mod_indices["src"] += 1
parent_dir = os.path.join(host_codegen_dir, "src")
file_name = os.path.join(parent_dir, f"lib{index}.c")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.c")
elif dso_mod.type_key == "llvm":
index = mod_indices["lib"]
mod_indices["lib"] += 1
parent_dir = os.path.join(host_codegen_dir, "lib")
file_name = os.path.join(parent_dir, f"lib{index}.o")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.o")
else:
assert (
False
Expand Down Expand Up @@ -98,7 +103,6 @@ def _build_sid_map(graph_json):
A list with one entry per storage id describing that memory.
"""
graph = json.loads(graph_json)

seen_storage_ids = set()
memory_map = []
for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
Expand Down Expand Up @@ -227,7 +231,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil
runtime = ["aot"] if is_aot else ["graph"]

metadata = {
"version": 2,
"version": 3,
"model_name": mod.libmod_name,
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
"memory": _build_memory_map(mod),
Expand All @@ -240,7 +244,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil

codegen_dir_path = tempdir.relpath("codegen")
os.mkdir(codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name)

parameters_dir_path = tempdir.relpath("parameters")
os.mkdir(parameters_dir_path)
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.runtime import Object
from tvm.support import libinfo
from tvm.target import Target
from ..backend.utils import mangle_module_name
from .. import function as _function
from .. import ty as _ty
from . import _backend
Expand Down Expand Up @@ -328,7 +329,7 @@ class CompileEngine(Object):
def __init__(self):
raise RuntimeError("Cannot construct a CompileEngine")

def lower(self, source_func, target=None):
def lower(self, source_func, target=None, mod_name="default"):
"""Lower a source_func to a CachedFunc.
Parameters
Expand All @@ -346,8 +347,9 @@ def lower(self, source_func, target=None):
"""
# pylint: disable=broad-except, import-outside-toplevel
try:
mod_name = mangle_module_name(mod_name)
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key)
return _backend._CompileEngineLower(self, key, mod_name)
except Exception:
import traceback

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/graph_executor_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from tvm.relay import _build_module
from tvm.target import Target
from tvm.tir import expr as _expr
from .utils import mangle_module_name


class GraphExecutorCodegen(object):
Expand Down Expand Up @@ -80,7 +81,8 @@ def codegen(self, func):
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self._codegen(func)
default_mod_name = mangle_module_name("default")
self._codegen(func, default_mod_name)
graph_json = self._get_graph_json()
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relay/backend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility backend functions."""


def _is_valid_modname(mod_name):
"""Determine if mod_name is a valid string to use inside function names"""
if mod_name:
try:
mod_name.encode("ascii")
return True
except UnicodeEncodeError:
return False

return True


def mangle_module_name(mod_name):
if not _is_valid_modname(mod_name):
raise ValueError(mod_name + " contains invalid characters")
if mod_name:
return "tvmgen_" + mod_name
return "tvmgen"
15 changes: 12 additions & 3 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import expr as _expr
from . import function as _function
from .transform import InferType
from .backend.utils import mangle_module_name
from .backend import executor_factory as _executor_factory
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
Expand Down Expand Up @@ -85,7 +86,9 @@ def __init__(self):
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]

def build(self, mod, target=None, target_host=None, params=None, executor="graph"):
def build(
self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
):
"""
Parameters
----------
Expand Down Expand Up @@ -115,6 +118,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
- If "graph" is specified, then the graph_executor will be used
- If "aot" is specified, then the aot_executor will be used
mod_name: Optional[str]
The module name we will build
Returns
-------
graph_json : str
Expand Down Expand Up @@ -145,7 +151,9 @@ def build(self, mod, target=None, target_host=None, params=None, executor="graph
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler

self._build(mod, target, target_host, executor)
mod_name = mangle_module_name(mod_name)

self._build(mod, target, target_host, executor, mod_name)
autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent

# Get artifacts
Expand Down Expand Up @@ -295,6 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
"""
# pylint: enable=line-too-long
# fmt: on

if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

Expand Down Expand Up @@ -330,7 +339,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
with tophub_context:
bld_mod = BuildModule()
executor_config, runtime_mod, params = bld_mod.build(
mod=ir_mod, target=target, params=params, executor=executor
mod=ir_mod, target=target, params=params, executor=executor, mod_name=mod_name
)
func_metadata = bld_mod.get_function_metadata()

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tvm.runtime import ndarray as _nd

from . import _ffi_api
from ..backend.utils import mangle_module_name


def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None):
Expand Down Expand Up @@ -713,7 +714,7 @@ def LambdaLift():
return _ffi_api.LambdaLift()


def PartitionGraph():
def PartitionGraph(mod_name="default"):
"""Partition a Relay program into regions that can be executed on different
backends.
Expand All @@ -722,7 +723,8 @@ def PartitionGraph():
ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
return _ffi_api.PartitionGraph()
mod_name = mangle_module_name(mod_name)
return _ffi_api.PartitionGraph(mod_name)


def AnnotateTarget(targets, include_non_call_ops=True):
Expand Down
30 changes: 20 additions & 10 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,12 @@ class AOTExecutorCodegen : public ExprVisitor {
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

Expand Down Expand Up @@ -410,7 +411,7 @@ class AOTExecutorCodegen : public ExprVisitor {
target = targets_[call_dev_type];
}
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_);
if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
Expand Down Expand Up @@ -533,7 +534,10 @@ class AOTExecutorCodegen : public ExprVisitor {

// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", runtime::String(runtime::symbol::tvm_run_func_prefix));
String run_func_name =
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

// Make the PrimFunc
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
Expand Down Expand Up @@ -586,6 +590,8 @@ class AOTExecutorCodegen : public ExprVisitor {
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more then one output */
IntegerArray return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
Expand All @@ -595,10 +601,11 @@ class AOTExecutorCodegen : public ExprVisitor {
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
compile_engine_(CompileEngine::Global()) {}

LoweredOutput Codegen(relay::Function func) {
LoweredOutput Codegen(relay::Function func, String mod_name) {
// Get the module, storage map and token sizes
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
storage_device_map_ = (*pf)(func);
mod_name_ = mod_name;

for (auto input : func->params) {
input_vars_.push_back(input);
Expand Down Expand Up @@ -645,15 +652,15 @@ class AOTExecutorCodegen : public ExprVisitor {
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
} else {
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
}
ret.function_metadata = std::move(function_metadata_);
ret.metadata =
runtime::Metadata(input_vars_.size(), return_sid_.size(), runtime::kTvmExecutorAot);
ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
return ret;
}
};
Expand All @@ -673,7 +680,8 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
} else if (name == "codegen") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Function func = args[0];
this->output_ = codegen(func);
String mod_name = args[1];
this->output_ = codegen(func, mod_name);
});
} else if (name == "list_params_name") {
return PackedFunc(
Expand Down Expand Up @@ -724,7 +732,9 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
targets, target_host);
}

LoweredOutput codegen(Function func) { return this->codegen_->Codegen(func); }
LoweredOutput codegen(Function func, String mod_name) {
return this->codegen_->Codegen(func, mod_name);
}

Array<runtime::String> list_params_name() {
Array<runtime::String> ret;
Expand Down
Loading

0 comments on commit 86fee54

Please sign in to comment.