Skip to content

Commit a7f71e1

Browse files
committed
[AOT] BugFix of workspace calculation
Following an investigation from apache#10022, it turns out that the workspace calculation currently assumes that a single lowered PrimFunc is produced per primitive Relay function. The exception is the CMSIS-NN codegen, which produces multiple calls/PrimFuncs in place of a single call to a single Relay PrimFunc. This commit changes the workspace calculation so that it is done on the lowered IRModule. Additionally, it changes the test utils to not generate any stack allocator code when USMP is used, which makes the tests stricter. This change also removes the confusing "run_model", which has semantics identical to "__tvm_main__" in TIR. Change-Id: I5202d9cc7c6a8c00c73791b82df062a8e13dd224
1 parent 6c6e873 commit a7f71e1

File tree

14 files changed

+213
-88
lines changed

14 files changed

+213
-88
lines changed

apps/microtvm/zephyr_cmsisnn/src/main.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ extern float output_storage[12];
3434

3535
extern const size_t output_len;
3636

37-
static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE + 512];
37+
static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE];
3838
tvm_workspace_t app_workspace;
3939

4040
void TVMLogf(const char* msg, ...) {

include/tvm/runtime/module.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,6 @@ constexpr const char* tvm_module_main = "__tvm_main__";
235235
constexpr const char* tvm_param_prefix = "__tvm_param__";
236236
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
237237
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
238-
/*! \brief The main AOT executor function generated from TIR */
239-
constexpr const char* tvm_run_func_suffix = "run_model";
240238
/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */
241239
constexpr const char* tvm_entrypoint_suffix = "run";
242240
} // namespace symbol

src/relay/backend/aot_executor_codegen.cc

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
658658

659659
// Define the PrimFunc attributes
660660
Map<String, ObjectRef> dict_attrs;
661-
String run_func_name =
662-
runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix);
661+
String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main);
663662
dict_attrs.Set("global_symbol", run_func_name);
664663
dict_attrs.Set("runner_function", Bool(true));
665664
dict_attrs.Set(tvm::attr::kTarget, target_host_);
@@ -702,6 +701,35 @@ class AOTExecutorCodegen : public MixedModeVisitor {
702701
}
703702
}
704703

704+
/*!
705+
* brief Calculate workspace sizes for PrimFuncs in the IRModule
706+
*/
707+
Map<String, FunctionInfo> CalculateWorkspaceSizes(
708+
const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
709+
Executor executor_config = lowered_mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
710+
Integer workspace_byte_alignment =
711+
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
712+
Map<String, FunctionInfo> updated_function_metadata;
713+
for (const auto& kv : lowered_mod->functions) {
714+
GlobalVar global_var = kv.first;
715+
BaseFunc base_func = kv.second;
716+
if (base_func->IsInstance<tir::PrimFuncNode>()) {
717+
tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
718+
Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
719+
const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
720+
if (function_metadata.count(global_var->name_hint)) {
721+
updated_function_metadata.Set(global_var->name_hint,
722+
function_metadata[global_var->name_hint]);
723+
updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws);
724+
} else {
725+
FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}};
726+
updated_function_metadata.Set(global_var->name_hint, finfo);
727+
}
728+
}
729+
}
730+
return updated_function_metadata;
731+
}
732+
705733
/*!
706734
* brief Run USMP to plan memory for lowered IRModule
707735
*/
@@ -710,17 +738,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
710738
Integer workspace_byte_alignment =
711739
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
712740
IRModule lowered_mod = mod->ShallowCopy();
741+
function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
713742
lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
714-
// Update workspace size based on the pool allocations.
715-
for (const auto& kv : function_metadata_) {
716-
if (lowered_mod->ContainGlobalVar(kv.first) &&
717-
lowered_mod->Lookup(kv.first)->IsInstance<tir::PrimFuncNode>()) {
718-
tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(lowered_mod->Lookup(kv.first));
719-
Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
720-
const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
721-
kv.second->workspace_sizes.Set(tgt, ws);
722-
}
723-
}
724743
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
725744
lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
726745
backend::FunctionInfo main_func_info =
@@ -752,17 +771,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
752771
Integer workspace_byte_alignment =
753772
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
754773
IRModule lowered_mod = mod->ShallowCopy();
774+
function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
755775
// Running StorageRewrite just on the main function
756776
tir::PrimFunc tir_main_func =
757-
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
777+
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
758778
IRModule main_func_mod;
759-
main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix),
779+
main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
760780
tir_main_func);
761781
main_func_mod = tir::transform::StorageRewrite()(main_func_mod);
762-
lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix),
763-
main_func_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
782+
lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
783+
main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
764784
tir_main_func =
765-
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
785+
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
766786
// Use the PrimFunc to calculate the workspace required to service the allocates
767787
Integer main_workspace_size_bytes =
768788
CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment);
@@ -920,7 +940,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
920940
// function and replacing it with its TIR version. We should try to make this a Pass.
921941
lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
922942
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
923-
lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
943+
lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), prim_func);
924944
// Parallel for loops are not supported in AoT codegen.
925945
lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);
926946

@@ -960,7 +980,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
960980
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
961981
std::vector<tir::Var> pool_vars;
962982
tir::PrimFunc tir_main_func =
963-
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
983+
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
964984
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
965985
tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
966986
if (allocated_pool_infos) {

src/target/source/source_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
474474
}
475475

476476
void GenerateAOTDescriptor() {
477-
const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix;
477+
const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_module_main;
478478
const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix;
479479
const std::string run_func_mangled =
480480
runtime::get_name_mangled(metadata_->mod_name, run_func_suffix);

src/tir/usmp/transform/assign_pool_info.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class PoolInfoAssigner : public StmtExprMutator {
4242
public:
4343
explicit PoolInfoAssigner(const IRModule& module) {
4444
PrimFunc main_func =
45-
Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
45+
Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
4646
ICHECK(main_func.defined()) << "main function is not in the module";
4747
Optional<Target> target_host = main_func->GetAttr<Target>(tvm::attr::kTarget);
4848
ICHECK(target_host) << "main function does not have a target attr";
@@ -79,7 +79,7 @@ class PoolInfoAssigner : public StmtExprMutator {
7979
PoolInfo PoolInfoAssigner::CreateDefaultMemoryPool(const tvm::IRModule& module) {
8080
Map<Target, String> target_access;
8181
tir::PrimFunc tir_main_func =
82-
Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
82+
Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
8383
Target target_host = tir_main_func->GetAttr<Target>(tvm::attr::kTarget).value();
8484
for (const auto& kv : module->functions) {
8585
BaseFunc func = kv.second;

src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) {
331331
}
332332

333333
IRModule PoolAllocationToOffsetConverter::operator()() {
334-
GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix);
334+
GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main);
335335
PrimFunc main_func = Downcast<PrimFunc>(module_->Lookup(gv));
336336
ScopeInfo si = UpdateFunctionScopeInfo(main_func);
337337
this->scope_stack.push(si);

src/tir/usmp/unified_static_memory_planner.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(
5151

5252
IRModule PlanMemory(const IRModule& mod, String algo) {
5353
VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
54-
PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
54+
PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
5555
BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, mod);
5656
Array<BufferInfo> buffer_info_arr =
5757
CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts);
@@ -63,7 +63,7 @@ IRModule PlanMemory(const IRModule& mod, String algo) {
6363
buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations);
6464
IRModule ret = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(mod);
6565
tir::PrimFunc tir_main_func =
66-
Downcast<tir::PrimFunc>(ret->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
66+
Downcast<tir::PrimFunc>(ret->Lookup(::tvm::runtime::symbol::tvm_module_main));
6767
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
6868
tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
6969
if (allocated_pool_infos) {

src/tir/usmp/utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class ModuleWorkspaceSizeCalculator : public StmtExprVisitor {
181181
for (const auto& gv_func : mod_->functions) {
182182
functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
183183
}
184-
main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
184+
main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
185185
ICHECK(main_func_.defined()) << "main function is not in the module";
186186
Optional<Target> target_host = main_func_->GetAttr<Target>(tvm::attr::kTarget);
187187
ICHECK(target_host) << "main function does not have a target attr";

tests/python/contrib/test_ethosu/infra.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,13 @@ def build_source(
242242
def verify_source(
243243
models: List[AOTCompiledTestModel],
244244
accel="ethos-u55-256",
245+
enable_usmp=True,
245246
):
246247
"""
247248
This method verifies the generated source from an NPU module by building it and running on an FVP.
248249
"""
249250
interface_api = "c"
250-
test_runner = create_test_runner(accel)
251+
test_runner = create_test_runner(accel, enable_usmp)
251252
run_and_check(
252253
models,
253254
test_runner,

tests/python/contrib/test_ethosu/test_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_forward_mobilenet_v1(accel_type, enable_usmp):
7171
compiled_models = infra.build_source(
7272
mod, input_data, output_data, accel_type, output_tolerance=10, enable_usmp=enable_usmp
7373
)
74-
infra.verify_source(compiled_models, accel_type)
74+
infra.verify_source(compiled_models, accel_type, enable_usmp=enable_usmp)
7575

7676

7777
if __name__ == "__main__":

0 commit comments

Comments
 (0)