@@ -642,8 +642,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
642642
643643 // Define the PrimFunc attributes
644644 Map<String, ObjectRef> dict_attrs;
645- String run_func_name =
646- runtime::get_name_mangled (mod_name, runtime::symbol::tvm_run_func_suffix);
645+ String run_func_name = runtime::get_name_mangled (mod_name, runtime::symbol::tvm_module_main);
647646 dict_attrs.Set (" global_symbol" , run_func_name);
648647 dict_attrs.Set (" runner_function" , Bool (true ));
649648 dict_attrs.Set (tvm::attr::kTarget , target_host_);
@@ -686,6 +685,35 @@ class AOTExecutorCodegen : public MixedModeVisitor {
686685 }
687686 }
688687
688+ /*!
689+ * \brief Calculate workspace sizes for the PrimFuncs in the IRModule.
     *
     * Builds a fresh name -> FunctionInfo map that contains one entry per
     * tir::PrimFunc found in lowered_mod->functions, keyed by the function's
     * GlobalVar name_hint. For each such PrimFunc the workspace size returned by
     * CalculateWorkspaceBytes is stored under the PrimFunc's kTarget target.
     *
     * Entries of function_metadata whose names do not match a PrimFunc in
     * lowered_mod are not copied into the returned map.
     *
     * \param lowered_mod The module whose PrimFuncs are measured. Its kExecutor
     *        attribute is read with .value(), so it is presumably required to be
     *        set -- confirm against callers.
     * \param function_metadata Existing per-function metadata; entries whose name
     *        matches a PrimFunc are reused as the starting point for that PrimFunc.
     * \return The rebuilt metadata map with workspace sizes filled in.
690+ */
691+ Map<String, FunctionInfo> CalculateWorkspaceSizes (
692+ const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
693+ Executor executor_config = lowered_mod->GetAttr <Executor>(tvm::attr::kExecutor ).value ();
     // Alignment comes from the executor attributes and falls back to 16 when unset.
     // NOTE(review): the key literal below starts with a space in this text; this
     // looks like an extraction artifact -- confirm against the upstream source.
694+ Integer workspace_byte_alignment =
695+ executor_config->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
696+ Map<String, FunctionInfo> updated_function_metadata;
697+ for (const auto & kv : lowered_mod->functions ) {
698+ GlobalVar global_var = kv.first ;
699+ BaseFunc base_func = kv.second ;
     // Only TIR PrimFuncs are measured; every other kind of function is skipped.
700+ if (base_func->IsInstance <tir::PrimFuncNode>()) {
701+ tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
     // .value() presumes every PrimFunc carries a kTarget attribute -- TODO confirm.
702+ Target tgt = pfunc->GetAttr <Target>(tvm::attr::kTarget ).value ();
703+ const auto & ws = CalculateWorkspaceBytes (pfunc, workspace_byte_alignment);
704+ if (function_metadata.count (global_var->name_hint )) {
     // Known function: carry over its existing metadata, then record the
     // workspace size for this target on it.
     // NOTE(review): the copied entry is likely the same underlying FunctionInfo
     // object as the one in function_metadata, so the input entry may observe
     // this update as well -- confirm.
705+ updated_function_metadata.Set (global_var->name_hint ,
706+ function_metadata[global_var->name_hint ]);
707+ updated_function_metadata[global_var->name_hint ]->workspace_sizes .Set (tgt, ws);
708+ } else {
     // Unknown function: create a new FunctionInfo. Only the first field
     // ({tgt -> ws}) and the fourth field ({tgt -> pfunc}) are populated;
     // presumably these are the workspace sizes and the TIR PrimFuncs --
     // verify against the FunctionInfo constructor.
709+ FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}};
710+ updated_function_metadata.Set (global_var->name_hint , finfo);
711+ }
712+ }
713+ }
714+ return updated_function_metadata;
715+ }
716+
689717 /*!
690718 * \brief Run USMP to plan memory for the lowered IRModule
691719 */
@@ -694,17 +722,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
694722 Integer workspace_byte_alignment =
695723 executor_config->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
696724 IRModule lowered_mod = mod->ShallowCopy ();
725+ function_metadata_ = CalculateWorkspaceSizes (lowered_mod, function_metadata_);
697726 lowered_mod = tir::transform::UnifiedStaticMemoryPlanner ()(lowered_mod);
698- // Update workspace size based on the pool allocations.
699- for (const auto & kv : function_metadata_) {
700- if (lowered_mod->ContainGlobalVar (kv.first ) &&
701- lowered_mod->Lookup (kv.first )->IsInstance <tir::PrimFuncNode>()) {
702- tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(lowered_mod->Lookup (kv.first ));
703- Target tgt = pfunc->GetAttr <Target>(tvm::attr::kTarget ).value ();
704- const auto & ws = CalculateWorkspaceBytes (pfunc, workspace_byte_alignment);
705- kv.second ->workspace_sizes .Set (tgt, ws);
706- }
707- }
708727 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
709728 lowered_mod->GetAttr <Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs );
710729 backend::FunctionInfo main_func_info =
@@ -736,17 +755,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
736755 Integer workspace_byte_alignment =
737756 executor_config->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
738757 IRModule lowered_mod = mod->ShallowCopy ();
758+ function_metadata_ = CalculateWorkspaceSizes (lowered_mod, function_metadata_);
739759 // Running StorageRewrite just on the main function
740760 tir::PrimFunc tir_main_func =
741- Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
761+ Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
742762 IRModule main_func_mod;
743- main_func_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix ),
763+ main_func_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_module_main ),
744764 tir_main_func);
745765 main_func_mod = tir::transform::StorageRewrite ()(main_func_mod);
746- lowered_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix ),
747- main_func_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
766+ lowered_mod->Update (lowered_mod->GetGlobalVar (::tvm::runtime::symbol::tvm_module_main ),
767+ main_func_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
748768 tir_main_func =
749- Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
769+ Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
750770 // Use the PrimFunc to calculate the workspace required to service the allocates
751771 Integer main_workspace_size_bytes =
752772 CalculateWorkspaceBytes (tir_main_func, workspace_byte_alignment);
@@ -903,7 +923,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
903923 // function and replacing it with its TIR version. We should try to make this a Pass.
904924 lowered_mod->Remove (lowered_mod->GetGlobalVar (" main" ));
905925 auto prim_func = CreateMainFunc (mod_name, lowered_main_func->params .size ());
906- lowered_mod->Update (GlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix ), prim_func);
926+ lowered_mod->Update (GlobalVar (::tvm::runtime::symbol::tvm_module_main ), prim_func);
907927 // Parallel for loops are not supported in AoT codegen.
908928 lowered_mod = tir::transform::ConvertForLoopsToSerial ()(lowered_mod);
909929
@@ -943,7 +963,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
943963 Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
944964 std::vector<tir::Var> pool_vars;
945965 tir::PrimFunc tir_main_func =
946- Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_run_func_suffix ));
966+ Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main ));
947967 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
948968 tir_main_func->GetAttr <Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs );
949969 if (allocated_pool_infos) {
0 commit comments