Skip to content

Commit 4ddb876

Browse files
authored
[microTVM] Fix host-driven AOT memory workspaces (#13807)
When using host-driven AOT with memory pools enabled, the workspace and constant memory were not properly supported. In order for them to work properly, the _run function (typically tvmgen_default_run()) needed to be called instead of tvmgen_default___tvm_main__() in order to properly setup the memory workspace pointers. fixes #13777
1 parent f7dfef4 commit 4ddb876

File tree

4 files changed

+19
-29
lines changed

4 files changed

+19
-29
lines changed

src/runtime/crt/aot_executor/aot_executor.c

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) {
8383
}
8484

8585
int TVMAotExecutor_Run(TVMAotExecutor* executor) {
86-
const char* tvm_main_suffix = "___tvm_main__";
86+
const char* tvm_main_suffix = "_run";
8787
char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME];
8888

8989
{
@@ -203,17 +203,6 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle,
203203
TVMNDArray_IncrementReference(array);
204204
}
205205

206-
for (i = 0; i < md->num_workspace_pools; ++i) {
207-
LOG_DEBUG("pools allocate[%d]: %s\n", i, md->workspace_pools[i].name);
208-
209-
status = TVMNDArray_Empty(md->workspace_pools[i].num_shape, md->workspace_pools[i].shape,
210-
md->workspace_pools[i].dtype, executor->device,
211-
&executor->args[arg_idx++]);
212-
if (status != 0) {
213-
return status;
214-
}
215-
}
216-
CHECK_EQ(0, md->num_constant_pools, "Constant pools not supported");
217206
return status;
218207
}
219208

src/target/source/source_module.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -929,11 +929,21 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& mod
929929
relay::backend::ExecutorCodegenMetadata metadata,
930930
runtime::metadata::Metadata aot_metadata) {
931931
Array<runtime::Module> final_modules(modules);
932-
if (aot_metadata.defined()) {
933-
final_modules.push_back(CreateAotMetadataModule(aot_metadata, true));
932+
Array<String> func_names;
933+
934+
if (metadata.defined()) {
935+
if (metadata->executor == "aot") {
936+
if (aot_metadata.defined()) {
937+
final_modules.push_back(CreateAotMetadataModule(aot_metadata, true));
938+
}
939+
940+
// add the run function (typically "tvmgen_default_run") to function registry
941+
// when using AOT executor
942+
std::string run_func = runtime::get_name_mangled(metadata->mod_name, "run");
943+
func_names.push_back(run_func);
944+
}
934945
}
935946

936-
Array<String> func_names;
937947
for (runtime::Module mod : final_modules) {
938948
auto pf_funcs = mod.GetFunction("get_func_names");
939949
if (pf_funcs != nullptr) {

tests/python/unittest/test_crt.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,10 @@ def do_test():
229229
do_test()
230230

231231

232-
enable_usmp, expect_exception = tvm.testing.parameters((True, True), (False, False))
233-
234-
235232
@tvm.testing.requires_micro
236-
def test_aot_executor_usmp_const_pool(enable_usmp, expect_exception):
237-
"""Test the AOT executor with microTVM using usmp.
238-
Test should fail if const pool is supplied to executor
239-
as these are currently not supported
240-
"""
233+
def test_aot_executor_usmp_const_pool():
234+
"""Test the AOT executor with microTVM using USMP to generate a constant data pool."""
235+
241236
ws_root = pathlib.Path(os.path.dirname(__file__) + "/micro-workspace-usmp")
242237
if ws_root.exists():
243238
shutil.rmtree(ws_root)
@@ -260,7 +255,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8], %c : Tensor[(1
260255
C_np = np.array([[8, 9]], dtype="uint8").astype(type_dict["c"])
261256
params = {"c": C_np}
262257
with tvm.transform.PassContext(
263-
opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp}
258+
opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": True}
264259
):
265260
factory = tvm.relay.build(
266261
relay_mod,
@@ -278,10 +273,7 @@ def do_test():
278273
)
279274
)
280275
except tvm._ffi.base.TVMError as e:
281-
if expect_exception:
282-
return
283-
else:
284-
raise e
276+
raise e
285277

286278
assert aot_executor.get_input_index("a") == 0
287279
assert aot_executor.get_input_index("b") == 1

tests/python/unittest/test_micro_model_library_format.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,6 @@ def test_multiple_relay_modules_aot_graph():
618618

619619
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib0.c"))
620620
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib1.c"))
621-
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib2.c"))
622621
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib0.c"))
623622
assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib1.c"))
624623

0 commit comments

Comments
 (0)