diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 31816a71987e..81c833a7c905 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1429,6 +1429,16 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, arg_shapes[i] = in_arg.shape(); arg_dtypes[i] = in_arg.dtype(); arg_stypes[i] = in_arg.storage_type(); + } + aux_top++; + } else { + CHECK_LT(args_top, args_len) + << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for"; + if (in_args_ptr[args_top] != nullptr) { + const auto &in_arg = *(in_args_ptr[args_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); } else { auto it_shape = input_shape_map.find(name); if (it_shape != input_shape_map.end()) { @@ -1443,16 +1453,6 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, arg_stypes[i] = it_type->second; } } - aux_top++; - } else { - CHECK_LT(args_top, args_len) - << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for"; - if (in_args_ptr[args_top] != nullptr) { - const auto &in_arg = *(in_args_ptr[args_top]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); - } args_top++; } }