diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 2ca92817d423..d99d69c7163f 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1924,6 +1924,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, std::unordered_map* shared_buffer, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); + bool init = false; if (!exec->subgraph_property().empty()) { static bool verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", false); const auto& backend_name = exec->subgraph_property(); @@ -1937,8 +1938,8 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, std::vector tmp_in_args; std::vector tmp_arg_grads; std::vector tmp_aux_states; - const std::vector arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs); - const std::vector aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + const auto arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + const auto aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates); symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes, verbose); @@ -1946,10 +1947,9 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads, &tmp_aux_states, shared_buffer, shared_exec); - const std::vector new_arg_names = - symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs); - const std::vector new_aux_names = - symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + init = true; + const auto new_arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + const auto new_aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates); std::unordered_map new_arg_names_idx_map; std::unordered_map new_aux_names_idx_map; for (size_t i = 0; i != new_arg_names.size(); ++i) { @@ -1979,7 +1979,9 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, aux_states->emplace_back(std::move(tmp_aux_states[it->second])); } } - } else { + } + if (!init) { + // init without subgraph exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names, in_args, arg_grads, aux_states, shared_buffer, shared_exec);