Skip to content

Commit 279f39b

Browse files
mehrdadhyongwww
authored andcommitted
[AoT]Add get_input_name function to AoT Module (apache#14071)
* Add get_input_name to C AOT * add get_input_name to AOT C++ * lint * fix bug in AotExecutor
1 parent cb66e6c commit 279f39b

File tree

11 files changed

+111
-21
lines changed

11 files changed

+111
-21
lines changed

include/tvm/runtime/crt/aot_executor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor);
9292
*/
9393
int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name);
9494

95+
/*!
96+
* \brief Return a pointer to name of input with the specified input index
97+
*
98+
* \param executor Pointer to executor instance, created by TVMAotExecutor_Create().
99+
* \param index Input index for retrieving name.
100+
* \param name Output for retrieving name.
101+
* \return Pointer to input name in `name`.
102+
*/
103+
int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name);
104+
95105
/*!
96106
* \brief Run the generated program.
97107
*

python/tvm/micro/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .session import (
3030
create_local_graph_executor,
3131
create_local_debug_executor,
32+
create_local_aot_executor,
3233
Session,
3334
SessionTerminatedError,
3435
)

python/tvm/micro/session.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import pathlib
2525
import shutil
2626
from typing import Union
27+
28+
from tvm.runtime.executor.aot_executor import AotModule
2729
from ..error import register_error
2830
from .._ffi import get_global_func, register_func
2931
from ..contrib import graph_executor
@@ -259,6 +261,22 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
259261
)
260262

261263

264+
def create_local_aot_executor(session: Session):
265+
"""Create a local AoT executor driving execution on the remote CPU device given.
266+
267+
Parameters
268+
----------
269+
session : Session
270+
A microTVM device session.
271+
272+
Returns
273+
-------
274+
tvm.runtime.executor.aot_executor.AotModule :
275+
A local AoT executor instance that executes on the remote device.
276+
"""
277+
return AotModule(session.create_aot_executor())
278+
279+
262280
@register_func("tvm.micro.compile_and_create_micro_session")
263281
def compile_and_create_micro_session(
264282
mod_src_bytes: bytes,

python/tvm/relay/build_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def _make_executor(self, expr=None):
575575
ret_type = self.mod["main"].checked_type.ret_type
576576
if _ty.is_dynamic(ret_type):
577577
raise ValueError("AOT Executor only supports static graphs, got output type", ret_type)
578-
mod = build(self.mod, target=self.target)
578+
mod = build(self.mod, target=self.target, executor=Executor("aot"))
579579

580580
# NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
581581
# generated code.

python/tvm/runtime/executor/aot_executor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, module):
6767
self._get_num_outputs = module["get_num_outputs"]
6868
self._get_input_index = module["get_input_index"]
6969
self._get_num_inputs = module["get_num_inputs"]
70+
self._get_input_name = module["get_input_name"]
7071

7172
def set_input(self, key=None, value=None, **params):
7273
"""Set inputs to the module via kwargs
@@ -180,3 +181,21 @@ def get_output(self, index, out=None):
180181
return out
181182

182183
return self._get_output(index)
184+
185+
def get_input_name(self, index: int) -> str:
186+
"""Return the name of input with index `index`"""
187+
return self._get_input_name(index)
188+
189+
def get_input_info(self):
190+
"""Return the 'shape' and 'dtype' dictionaries of the module."""
191+
self.get_input_name(0)
192+
193+
shape_dict = dict()
194+
dtype_dict = dict()
195+
for ind in range(0, self.get_num_inputs()):
196+
input_name = self.get_input_name(ind)
197+
input_tensor = self.get_input(ind)
198+
shape_dict[input_name] = input_tensor.shape
199+
dtype_dict[input_name] = input_tensor.dtype
200+
201+
return shape_dict, dtype_dict

src/runtime/aot_executor/aot_executor.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ PackedFunc AotExecutor::GetFunction(const std::string& name,
156156
CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
157157
*rv = this->GetInputIndex(tvm::runtime::SanitizeName(args[0].operator String()));
158158
});
159+
} else if (name == "get_input_name") {
160+
return PackedFunc(
161+
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetInputName(args[0]); });
159162
} else {
160163
return PackedFunc();
161164
}
@@ -191,6 +194,11 @@ int AotExecutor::GetInputIndex(const std::string& name) {
191194
return -1;
192195
}
193196

197+
std::string AotExecutor::GetInputName(int index) {
198+
auto inputs = metadata_->inputs();
199+
return inputs[index]->name();
200+
}
201+
194202
int AotExecutor::GetOutputIndex(const std::string& name) {
195203
auto outputs = metadata_->outputs();
196204
for (unsigned int i = 0; i < outputs.size(); i++) {

src/runtime/aot_executor/aot_executor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ class TVM_DLL AotExecutor : public ModuleNode {
6969
*/
7070
int GetInputIndex(const std::string& name);
7171

72+
/*!
73+
* \brief Get the input name given the index of input.
74+
* \param index The index of the input.
75+
* \return The name of input.
76+
*/
77+
std::string GetInputName(int index);
78+
7279
/*!
7380
* \brief Get the output index given the name of output.
7481
* \param name The name of the output.

src/runtime/crt/aot_executor/aot_executor.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) {
8282
return rv;
8383
}
8484

85+
int TVMAotExecutor_GetInputName(TVMAotExecutor* executor, int index, char** name) {
86+
const TVMMetadata* md = executor->metadata;
87+
*name = md->inputs[index].name;
88+
return 0;
89+
}
90+
8591
int TVMAotExecutor_Run(TVMAotExecutor* executor) {
8692
const char* tvm_main_suffix = "_run";
8793
char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];

src/runtime/crt/aot_executor_module/aot_executor_module.c

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,24 @@ int32_t TVMAotExecutorModule_GetInputIndex(TVMValue* args, int* tcodes, int narg
147147
return 0;
148148
}
149149

150+
int32_t TVMAotExecutorModule_GetInputName(TVMValue* args, int* tcodes, int nargs,
151+
TVMValue* ret_values, int* ret_tcodes,
152+
void* resource_handle) {
153+
if (nargs != 1) {
154+
return kTvmErrorFunctionCallNumArguments;
155+
}
156+
157+
char* name;
158+
int ret = TVMAotExecutor_GetInputName(aot_executor.executor, args[0].v_int64, &name);
159+
if (ret < 0) {
160+
return kTvmErrorExecutorModuleNoSuchInput;
161+
}
162+
163+
ret_values[0].v_str = name;
164+
ret_tcodes[0] = kTVMStr;
165+
return 0;
166+
}
167+
150168
int32_t TVMAotExecutorModule_GetNumInputs(TVMValue* args, int* tcodes, int nargs,
151169
TVMValue* ret_values, int* ret_tcodes,
152170
void* resource_handle) {
@@ -191,10 +209,11 @@ static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = {
191209
&TVMAotExecutorModule_Run, // run
192210
&TVMAotExecutorModule_NotImplemented, // set_input (implemented via python wrapper)
193211
&TVMAotExecutorModule_NotImplemented, // share_params (do not implement)
212+
&TVMAotExecutorModule_GetInputName, // get_input_name
194213
};
195214

196215
static const TVMFuncRegistry aot_executor_registry = {
197-
"\x0a\0get_input\0"
216+
"\x0b\0get_input\0"
198217
"get_input_index\0"
199218
"get_input_info\0"
200219
"get_num_inputs\0"
@@ -203,7 +222,8 @@ static const TVMFuncRegistry aot_executor_registry = {
203222
"load_params\0"
204223
"run\0"
205224
"set_input\0"
206-
"share_params\0",
225+
"share_params\0"
226+
"get_input_name\0",
207227
aot_executor_registry_funcs};
208228

209229
tvm_crt_error_t TVMAotExecutorModule_Register() {

tests/python/relay/aot/test_cpp_aot.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
112112
loaded_mod = tvm.runtime.load_module(test_so_path)
113113
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
114114
runner.set_input(**inputs)
115+
116+
assert runner.get_input_name(0) == "data"
117+
shape_dict, dtype_dict = runner.get_input_info()
118+
assert shape_dict == {"data": (1, 3, 64, 64)}
119+
assert dtype_dict == {"data": "uint8"}
120+
115121
runner.run()
116122
assert (runner.get_output(0).numpy() == list(ref_outputs.values())[0]).all()
117123

0 commit comments

Comments
 (0)