@@ -1112,32 +1112,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11121112 auto tir_main_func = CreateMainFunc (mod_name, lowered_main_func->params .size ());
11131113 // Extract additional information around main TIR PrimFunc arguments
11141114 Array<String> devices = ListDevices ();
1115- Array<tir::Var> inputs =
1116- Array<tir::Var>(tir_main_func->params .begin (), tir_main_func->params .begin () +
1117- tir_main_func->params .size () -
1118- return_sid_.size () - devices.size ());
1115+ const auto main_func_params_end_iterator =
1116+ tir_main_func->params .begin () + tir_main_func->params .size ();
1117+ const auto outputs_begin_iterator =
1118+ main_func_params_end_iterator - return_sid_.size () - devices.size ();
1119+ Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params .begin (), outputs_begin_iterator);
11191120 Array<TensorType> input_tensor_types;
11201121 for (auto i : inputs) {
11211122 input_tensor_types.push_back (io_tensor_types_[i]);
11221123 }
1123-
1124+ Array<tir::Var> outputs =
1125+ Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size ());
11241126 std::vector<String> output_var_names;
1125- if (auto opt = func->GetAttr <Array<String>>(" output_tensor_names" )) {
1126- Array<String> output_tensor_names = opt.value ();
1127- for (size_t i = 0 ; i < output_tensor_names.size (); ++i) {
1128- output_var_names.push_back (output_tensor_names[i]);
1129- }
1130- }
1131-
1132- // If output names have not been specified then generate default output names
1133- if (output_var_names.size () == 0 ) {
1134- if (return_sid_.size () == 1 ) {
1135- output_var_names.push_back (String (" output" ));
1136- } else {
1137- for (size_t i = 0 ; i < return_sid_.size (); ++i) {
1138- output_var_names.push_back (String (" output" + std::to_string (i)));
1139- }
1140- }
1127+ for (const tir::Var& output : outputs) {
1128+ output_var_names.push_back (output->name_hint );
11411129 }
11421130
11431131 Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes ()};
0 commit comments