diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c index a40c1d530fa9..ae007037e6cc 100644 --- a/src/runtime/crt/aot_executor/aot_executor.c +++ b/src/runtime/crt/aot_executor/aot_executor.c @@ -83,7 +83,7 @@ int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) { } int TVMAotExecutor_Run(TVMAotExecutor* executor) { - const char* tvm_main_suffix = "___tvm_main__"; + const char* tvm_main_suffix = "_run"; char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; { @@ -203,17 +203,6 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, TVMNDArray_IncrementReference(array); } - for (i = 0; i < md->num_workspace_pools; ++i) { - LOG_DEBUG("pools allocate[%d]: %s\n", i, md->workspace_pools[i].name); - - status = TVMNDArray_Empty(md->workspace_pools[i].num_shape, md->workspace_pools[i].shape, - md->workspace_pools[i].dtype, executor->device, - &executor->args[arg_idx++]); - if (status != 0) { - return status; - } - } - CHECK_EQ(0, md->num_constant_pools, "Constant pools not supported"); return status; } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ccc15fc1ee49..ee5a7cd33de9 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -929,11 +929,21 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod relay::backend::ExecutorCodegenMetadata metadata, runtime::metadata::Metadata aot_metadata) { Array final_modules(modules); - if (aot_metadata.defined()) { - final_modules.push_back(CreateAotMetadataModule(aot_metadata, true)); + Array func_names; + + if (metadata.defined()) { + if (metadata->executor == "aot") { + if (aot_metadata.defined()) { + final_modules.push_back(CreateAotMetadataModule(aot_metadata, true)); + } + + // add the run function (typically "tvmgen_default_run") to function registry + // when using AOT executor + std::string run_func = runtime::get_name_mangled(metadata->mod_name, "run"); + func_names.push_back(run_func); + } } - Array func_names; for (runtime::Module mod : final_modules) { auto pf_funcs = mod.GetFunction("get_func_names"); if (pf_funcs != nullptr) { diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 83fa91af06c9..3309aad0a5db 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -229,15 +229,10 @@ def do_test(): do_test() -enable_usmp, expect_exception = tvm.testing.parameters((True, True), (False, False)) - - @tvm.testing.requires_micro -def test_aot_executor_usmp_const_pool(enable_usmp, expect_exception): - """Test the AOT executor with microTVM using usmp. - Test should fail if const pool is supplied to executor - as these are currently not supported - """ +def test_aot_executor_usmp_const_pool(): + """Test the AOT executor with microTVM using USMP to generate a constant data pool.""" + ws_root = pathlib.Path(os.path.dirname(__file__) + "/micro-workspace-usmp") if ws_root.exists(): shutil.rmtree(ws_root) @@ -260,7 +255,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8], %c : Tensor[(1 C_np = np.array([[8, 9]], dtype="uint8").astype(type_dict["c"]) params = {"c": C_np} with tvm.transform.PassContext( - opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": enable_usmp} + opt_level=3, config={"tir.disable_vectorize": True, "tir.usmp.enable": True} ): factory = tvm.relay.build( relay_mod, @@ -278,10 +273,7 @@ def do_test(): ) ) except tvm._ffi.base.TVMError as e: - if expect_exception: - return - else: - raise e + raise e assert aot_executor.get_input_index("a") == 0 assert aot_executor.get_input_index("b") == 1 diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index e664c2ebb858..39919f337197 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -618,7 +618,6 @@ def test_multiple_relay_modules_aot_graph(): assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib0.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib1.c")) - assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod1_lib2.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib0.c")) assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "mod2_lib1.c"))