@@ -658,8 +658,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
658658
659659 // Define the PrimFunc attributes
660660 Map<String, ObjectRef> dict_attrs;
661- String run_func_name =
662- runtime::get_name_mangled (mod_name, runtime::symbol::tvm_run_func_suffix);
661+ String run_func_name = runtime::get_name_mangled (mod_name, runtime::symbol::tvm_module_main);
663662 dict_attrs.Set (" global_symbol" , run_func_name);
664663 dict_attrs.Set (" runner_function" , Bool (true ));
665664 dict_attrs.Set (tvm::attr::kTarget , target_host_);
@@ -702,6 +701,35 @@ class AOTExecutorCodegen : public MixedModeVisitor {
702701 }
703702 }
704703
704+ /* !
705+ * brief Calculate workspace sizes for PrimFuncs in the IRModule
706+ */
707+ Map<String, FunctionInfo> CalculateWorkspaceSizes (
708+ const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
709+ Executor executor_config = lowered_mod->GetAttr <Executor>(tvm::attr::kExecutor ).value ();
710+ Integer workspace_byte_alignment =
711+ executor_config->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
712+ Map<String, FunctionInfo> updated_function_metadata;
713+ for (const auto & kv : lowered_mod->functions ) {
714+ GlobalVar global_var = kv.first ;
715+ BaseFunc base_func = kv.second ;
716+ if (base_func->IsInstance <tir::PrimFuncNode>()) {
717+ tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
718+ Target tgt = pfunc->GetAttr <Target>(tvm::attr::kTarget ).value ();
719+ const auto & ws = CalculateWorkspaceBytes (pfunc, workspace_byte_alignment);
720+ if (function_metadata.count (global_var->name_hint )) {
721+ updated_function_metadata.Set (global_var->name_hint ,
722+ function_metadata[global_var->name_hint ]);
723+ updated_function_metadata[global_var->name_hint ]->workspace_sizes .Set (tgt, ws);
724+ } else {
725+ FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}};
726+ updated_function_metadata.Set (global_var->name_hint , finfo);
727+ }
728+ }
729+ }
730+ return updated_function_metadata;
731+ }
732+
705733 /* !
706734 * brief Run USMP to plan memory for lowered IRModule
707735 */
@@ -710,17 +738,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
710738 Integer workspace_byte_alignment =
711739 executor_config->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
712740 IRModule lowered_mod = mod->ShallowCopy ();
741+ function_metadata_ = CalculateWorkspaceSizes (lowered_mod, function_metadata_);
713742 lowered_mod = tir::transform::UnifiedStaticMemoryPlanner ()(lowered_mod);
714- // Update workspace size based on the pool allocations.
715- for (const auto & kv : function_metadata_) {
716- if (lowered_mod->ContainGlobalVar (kv.first ) &&
717- lowered_mod->Lookup (kv.first )->IsInstance <tir::PrimFuncNode>()) {
718- tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(lowered_mod->Lookup (kv.first ));
719- Target tgt = pfunc->GetAttr <Target>(tvm::attr::kTarget ).value ();
720- const auto & ws = CalculateWorkspaceBytes (pfunc, workspace_byte_alignment);
721- kv.second ->workspace_sizes .Set (tgt, ws);
722- }
723- }
724743 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
725744 lowered_mod->GetAttr <Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs );
726745 backend::FunctionInfo main_func_info =
@@ -752,17 +771,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
752771 Integer workspace_byte_alignment =
753772 executor_config->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
754773 IRModule lowered_mod = mod->ShallowCopy ();
774+ function_metadata_ = CalculateWorkspaceSizes (lowered_mod, function_metadata_);
755775 // Running StorageRewrite just on the main function
756776 tir::PrimFunc tir_main_func =
757- Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
777+ Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
758778 IRModule main_func_mod;
759- main_func_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix ),
779+ main_func_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_module_main ),
760780 tir_main_func);
761781 main_func_mod = tir::transform::StorageRewrite ()(main_func_mod);
762- lowered_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix ),
763- main_func_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
782+ lowered_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_module_main ),
783+ main_func_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
764784 tir_main_func =
765- Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
785+ Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
766786 // Use the PrimFunc to calculate the workspace required to service the allocates
767787 Integer main_workspace_size_bytes =
768788 CalculateWorkspaceBytes (tir_main_func, workspace_byte_alignment);
@@ -920,7 +940,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
920940 // function and replacing it with its TIR version. We should try to make this a Pass.
921941 lowered_mod->Remove (lowered_mod->GetGlobalVar (" main" ));
922942 auto prim_func = CreateMainFunc (mod_name, lowered_main_func->params .size ());
923- lowered_mod->Update (GlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix ), prim_func);
943+ lowered_mod->Update (GlobalVar (::tvm::runtime::symbol::tvm_module_main ), prim_func);
924944 // Parallel for loops are not supported in AoT codegen.
925945 lowered_mod = tir::transform::ConvertForLoopsToSerial ()(lowered_mod);
926946
@@ -960,7 +980,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
960980 Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
961981 std::vector<tir::Var> pool_vars;
962982 tir::PrimFunc tir_main_func =
963- Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
983+ Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
964984 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
965985 tir_main_func->GetAttr <Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs );
966986 if (allocated_pool_infos) {
0 commit comments