From 6055769bae5d661dbe6dbe1f27b30169950187eb Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 19 May 2020 04:08:21 -0700 Subject: [PATCH] Fix Signed-off-by: Serge Panev --- src/c_api/c_api_symbolic.cc | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 31816a71987e..42d0da5c5190 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1421,15 +1421,27 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, for (size_t i = 0; i < num_forward_inputs; ++i) { const uint32_t nid = indexed_graph.input_nodes().at(i); if (mutable_nodes.count(nid)) { - auto name = input_names[i]; CHECK_LT(aux_top, aux_len) - << "Cannot find aux '" << name << "' in provided aux to optimize_for"; + << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; if (in_aux_ptr[aux_top] != nullptr) { const auto &in_arg = *(in_aux_ptr[aux_top]); arg_shapes[i] = in_arg.shape(); arg_dtypes[i] = in_arg.dtype(); arg_stypes[i] = in_arg.storage_type(); + } + aux_top++; + } else { + auto name = input_names[i]; + CHECK_LT(args_top, args_len) + << "Cannot find arg '" << name << "' in provided args to optimize_for"; + if (in_args_ptr[args_top] != nullptr) { + const auto &in_arg = *(in_args_ptr[args_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); } else { + // input_names[i] is not in args but can be in the optional + // shape/type/stype attribute dicts. auto it_shape = input_shape_map.find(name); if (it_shape != input_shape_map.end()) { arg_shapes[i] = it_shape->second; @@ -1443,16 +1455,6 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, arg_stypes[i] = it_type->second; } } - aux_top++; - } else { - CHECK_LT(args_top, args_len) - << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for"; - if (in_args_ptr[args_top] != nullptr) { - const auto &in_arg = *(in_args_ptr[args_top]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); - } args_top++; } }