Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed May 19, 2020
1 parent 87dbfbf commit 6055769
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1421,15 +1421,27 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
auto name = input_names[i];
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << name << "' in provided aux to optimize_for";
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
aux_top++;
} else {
auto name = input_names[i];
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << name << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
// input_names[i] is not in args but can be in the optional
// shape/type/stype attribute dicts.
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
arg_shapes[i] = it_shape->second;
Expand All @@ -1443,16 +1455,6 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
arg_stypes[i] = it_type->second;
}
}
aux_top++;
} else {
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
args_top++;
}
}
Expand Down

0 comments on commit 6055769

Please sign in to comment.