Skip to content

Commit 10b6504

Browse files
committed
fixup! [USMP] Adding support for U4 usecase
Change-Id: I78f03d36b12b4a5e8eae8d11701f51019489defc
1 parent 8fa92ac commit 10b6504

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,32 +1112,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11121112
auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
11131113
// Extract additional information around main TIR PrimFunc arguments
11141114
Array<String> devices = ListDevices();
1115-
Array<tir::Var> inputs =
1116-
Array<tir::Var>(tir_main_func->params.begin(), tir_main_func->params.begin() +
1117-
tir_main_func->params.size() -
1118-
return_sid_.size() - devices.size());
1115+
const auto main_func_params_end_iterator =
1116+
tir_main_func->params.begin() + tir_main_func->params.size();
1117+
const auto outputs_begin_iterator =
1118+
main_func_params_end_iterator - return_sid_.size() - devices.size();
1119+
Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params.begin(), outputs_begin_iterator);
11191120
Array<TensorType> input_tensor_types;
11201121
for (auto i : inputs) {
11211122
input_tensor_types.push_back(io_tensor_types_[i]);
11221123
}
1123-
1124+
Array<tir::Var> outputs =
1125+
Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size());
11241126
std::vector<String> output_var_names;
1125-
if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
1126-
Array<String> output_tensor_names = opt.value();
1127-
for (size_t i = 0; i < output_tensor_names.size(); ++i) {
1128-
output_var_names.push_back(output_tensor_names[i]);
1129-
}
1130-
}
1131-
1132-
// If output names have not been specified then generate default output names
1133-
if (output_var_names.size() == 0) {
1134-
if (return_sid_.size() == 1) {
1135-
output_var_names.push_back(String("output"));
1136-
} else {
1137-
for (size_t i = 0; i < return_sid_.size(); ++i) {
1138-
output_var_names.push_back(String("output" + std::to_string(i)));
1139-
}
1140-
}
1127+
for (const tir::Var& output : outputs) {
1128+
output_var_names.push_back(output->name_hint);
11411129
}
11421130

11431131
Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};

0 commit comments

Comments
 (0)