@@ -784,13 +784,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
784784 * brief Create tir::Var for input/output while updating
785785 * the buffer_maps.
786786 */
787- void CreateIOVar (const Expr& expr, std::string name) {
787+ void CreateIOVar (const Expr& expr, const std::string& original_name,
788+ bool use_unique_name = true ) {
788789 if (expr->IsInstance <TupleNode>()) {
789790 Tuple tuple = Downcast<Tuple>(expr);
790791 for (unsigned i = 0 ; i < tuple->fields .size (); i++) {
791- CreateIOVar (tuple->fields [i], name + std::to_string (i) + " _ " );
792+ CreateIOVar (tuple->fields [i], original_name );
792793 }
793794 } else {
795+ std::string name = original_name;
796+ if (use_unique_name) {
797+ name = GetUniqueIOVarName (original_name);
798+ }
794799 tir::Var var = tir::Var (name, DataType::Handle ());
795800 main_signature_.push_back (var);
796801 auto tensor_type = expr->checked_type ().as <TensorTypeNode>();
@@ -804,6 +809,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
804809 }
805810 }
806811
812+ /* !
813+ * brief Create a unique name for I/O Var
814+ */
815+ std::string GetUniqueIOVarName (std::string name) {
816+ if (io_var_names_.find (name) == io_var_names_.end ()) {
817+ io_var_names_[name] = 1 ;
818+ return name;
819+ } else {
820+ io_var_names_[name] = io_var_names_[name] + 1 ;
821+ return name + std::to_string (io_var_names_[name]);
822+ }
823+ }
824+
807825 /* !
808826 * brief Calculate workspace sizes for PrimFuncs in the IRModule
809827 */
@@ -945,6 +963,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
945963 std::vector<tir::Stmt> stmts_;
946964 /*! \brief the list of return sids (note that the function might return more than one output) */
947965 std::vector<int > return_sid_;
966+ /*! \brief Per-I/O-var-name counter used to generate unique names */
967+ std::unordered_map<std::string, int > io_var_names_;
948968
949969 public:
950970 AOTExecutorCodegen (runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
@@ -1032,7 +1052,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10321052 for (auto input : lowered_main_func->params ) {
10331053 input_vars_.push_back (input);
10341054 std::string input_name = SanitizeName (input->name_hint ());
1035- CreateIOVar (input, input_name);
1055+ // We don't want the compiler changing input names in the
1056+ // event of a sanitization collision. Therefore, we force
1057+ // the created var to use input_name exactly.
1058+ CreateIOVar (input, input_name, /* use_unique_name = */ false );
10361059 }
10371060
10381061 // Define the storage allocator ids
@@ -1052,7 +1075,22 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10521075 // Retrieve the return sids
10531076 return_sid_ = final_aot_allocator.GetReturnIds ();
10541077 // Insert outputs to main func signature
1055- CreateIOVar (lowered_main_func->body , " output" );
1078+ // If output tensor names were provided use them
1079+ if (auto opt = func->GetAttr <Array<String>>(" output_tensor_names" )) {
1080+ Array<String> output_tensor_names = opt.value ();
1081+ if (lowered_main_func->body ->IsInstance <TupleNode>()) {
1082+ Tuple output_tuple = Downcast<Tuple>(lowered_main_func->body );
1083+ for (unsigned i = 0 ; i < output_tuple->fields .size (); i++) {
1084+ CreateIOVar (output_tuple->fields [i], output_tensor_names[i]);
1085+ }
1086+ } else {
1087+ CreateIOVar (lowered_main_func->body , output_tensor_names[0 ]);
1088+ }
1089+ } else {
1090+ // If output tensor names are not provided we will generate output(x)
1091+ // where x is a counter to create unique names.
1092+ CreateIOVar (lowered_main_func->body , " output" );
1093+ }
10561094
10571095 CollectDeviceVariables (lowered_mod->GetAttr <Map<GlobalVar, String>>(" device_contexts" ).value ());
10581096 VisitExpr (lowered_main_func->body );
@@ -1071,8 +1109,39 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10711109 // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
10721110 // function and replacing it with its TIR version. We should try to make this a Pass.
10731111 lowered_mod->Remove (lowered_mod->GetGlobalVar (" main" ));
1074- auto prim_func = CreateMainFunc (mod_name, lowered_main_func->params .size ());
1075- lowered_mod->Update (GlobalVar (::tvm::runtime::symbol::tvm_module_main), prim_func);
1112+ auto tir_main_func = CreateMainFunc (mod_name, lowered_main_func->params .size ());
1113+ // Extract additional information around main TIR PrimFunc arguments
1114+ Array<String> devices = ListDevices ();
1115+ Array<tir::Var> inputs =
1116+ Array<tir::Var>(tir_main_func->params .begin (), tir_main_func->params .begin () +
1117+ tir_main_func->params .size () -
1118+ return_sid_.size () - devices.size ());
1119+ Array<TensorType> input_tensor_types;
1120+ for (auto i : inputs) {
1121+ input_tensor_types.push_back (io_tensor_types_[i]);
1122+ }
1123+
1124+ std::vector<String> output_var_names;
1125+ if (auto opt = func->GetAttr <Array<String>>(" output_tensor_names" )) {
1126+ Array<String> output_tensor_names = opt.value ();
1127+ for (size_t i = 0 ; i < output_tensor_names.size (); ++i) {
1128+ output_var_names.push_back (output_tensor_names[i]);
1129+ }
1130+ }
1131+
1132+ // If output names have not been specified then generate default output names
1133+ if (output_var_names.size () == 0 ) {
1134+ if (return_sid_.size () == 1 ) {
1135+ output_var_names.push_back (String (" output" ));
1136+ } else {
1137+ for (size_t i = 0 ; i < return_sid_.size (); ++i) {
1138+ output_var_names.push_back (String (" output" + std::to_string (i)));
1139+ }
1140+ }
1141+ }
1142+
1143+ Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes ()};
1144+ lowered_mod->Update (GlobalVar (::tvm::runtime::symbol::tvm_module_main), tir_main_func);
10761145 // Parallel for loops are not supported in AoT codegen.
10771146 lowered_mod = tir::transform::ConvertForLoopsToSerial ()(lowered_mod);
10781147
@@ -1109,9 +1178,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11091178
11101179 ret.external_mods = external_modules.value ();
11111180
1181+ // Extract USMP metadata to pass onto metadata sources
11121182 Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
11131183 std::vector<tir::Var> pool_vars;
1114- tir::PrimFunc tir_main_func =
1184+ tir_main_func =
11151185 Downcast<tir::PrimFunc>(lowered_mod->Lookup (::tvm::runtime::symbol::tvm_module_main));
11161186 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
11171187 tir_main_func->GetAttr <Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs );
@@ -1122,41 +1192,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11221192 pool_var_info.Set (tir_main_func->params [pool_var_index], allocated_pool_info);
11231193 }
11241194 }
1125- Array<String> devices = ListDevices ();
1126- Array<tir::Var> inputs =
1127- Array<tir::Var>(tir_main_func->params .begin (),
1128- tir_main_func->params .begin () + tir_main_func->params .size () -
1129- return_sid_.size () - pool_vars.size () - devices.size ());
1130-
1131- Array<TensorType> input_tensor_types;
1132- for (auto i : inputs) {
1133- input_tensor_types.push_back (io_tensor_types_[i]);
1134- }
1195+ Map<String, tir::usmp::PoolAllocation> io_pool_allocations =
1196+ lowered_mod
1197+ ->GetAttr <Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations )
1198+ .value_or ({});
11351199
1136- std::vector<String> output_var_names;
1137- if (auto opt = func->GetAttr <Array<String>>(" output_tensor_names" )) {
1138- Array<String> output_tensor_names = opt.value ();
1139- for (size_t i = 0 ; i < output_tensor_names.size (); ++i) {
1140- output_var_names.push_back (output_tensor_names[i]);
1141- }
1142- }
1143-
1144- // If output names have not been specified then generate default output names
1145- if (output_var_names.size () == 0 ) {
1146- if (return_sid_.size () == 1 ) {
1147- output_var_names.push_back (String (" output" ));
1148- } else {
1149- for (size_t i = 0 ; i < return_sid_.size (); ++i) {
1150- output_var_names.push_back (String (" output" + std::to_string (i)));
1151- }
1152- }
1153- }
1154-
1155- Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes ()};
1200+ ret.metadata =
1201+ ExecutorCodegenMetadata (inputs, input_tensor_types, output_var_names, output_tensor_types,
1202+ pool_vars, devices, runtime::kTvmExecutorAot , mod_name,
1203+ interface_api, unpacked_api, pool_var_info, io_pool_allocations);
11561204
1157- ret.metadata = ExecutorCodegenMetadata (
1158- inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices,
1159- runtime::kTvmExecutorAot , mod_name, interface_api, unpacked_api, pool_var_info);
11601205 return ret;
11611206 }
11621207
0 commit comments