diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index b81b102f3299..0d22d79a4b6d 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -510,7 +510,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const std::vector& arg_grad_store, const std::vector& grad_req_types, const std::vector& aux_states, - Executor* shared_exec) { + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { // create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); }; auto get_ctx2 = [default_ctx](const NDArray& nd) { @@ -564,7 +565,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // Initialize the rest attributes of the graph. // This function can be called by regular bind // operation flow as well. - FinishInitGraph(symbol, g, shared_exec); + FinishInitGraph(symbol, g, shared_exec, feed_dict); } void GraphExecutor::InitV1(nnvm::Symbol symbol, @@ -782,7 +783,10 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, * \brief Finish graph initialization after shape and dtype inferences. * This function is used by both simple_bind and bind flows. */ -void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor* shared_exec) { +void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, + nnvm::Graph g, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { const auto& idx = g.indexed_graph(); for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; @@ -796,6 +800,11 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; } + for (const auto& kv : feed_dict) { + uint32_t eid = idx.entry_id(kv.first); + data_entry_[eid] = kv.second; + arg_storage_id[eid] = kExternalStorageID; + } g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); g = nnvm::ApplyPass(g, "PlanMemory"); } @@ -862,7 +871,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol, std::vector* arg_grad_vec, std::vector* aux_state_vec, std::unordered_map* shared_data_arrays, - Executor* shared_exec) { + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types); @@ -912,7 +922,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // Initialize the rest attributes of the graph. // This function can be called by regular bind // operation flow as well. - FinishInitGraph(symbol, g, shared_exec); + FinishInitGraph(symbol, g, shared_exec, feed_dict); } Graph GraphExecutor::InitGraphV1(nnvm::Symbol symbol, diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index e3c179ed541e..8e3bd0863183 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -53,7 +53,10 @@ class GraphExecutor : public Executor { void SetMonitorCallback(const MonitorCallback& callback) override; // Initialize the rest of attributes // after setting up arguments. - void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor* shared_exec = nullptr); + void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); // initialized the executor void InitV1(nnvm::Symbol symbol, @@ -63,7 +66,9 @@ class GraphExecutor : public Executor { const std::vector& arg_grad_store, const std::vector& grad_req_type, const std::vector& aux_states, - Executor* shared_exec = nullptr); + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); // initialize executor for bind void Init(nnvm::Symbol symbol, const Context& default_ctx, @@ -72,7 +77,9 @@ class GraphExecutor : public Executor { const std::vector& arg_grad_store, const std::vector& grad_req_types, const std::vector& aux_states, - Executor* shared_exec = nullptr); + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); // initialize executor for simple bind void Init(nnvm::Symbol symbol, const Context& default_ctx, @@ -91,7 +98,9 @@ class GraphExecutor : public Executor { std::vector* arg_grad_vec, std::vector* aux_state_vec, std::unordered_map* shared_data_arrays = nullptr, - Executor* shared_exec = nullptr); + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); protected: // Information about operational node @@ -160,7 +169,9 @@ class GraphExecutor : public Executor { const std::vector& in_args, const std::vector& arg_grad_store, const std::vector& grad_req_type, - const std::vector& aux_states); + const std::vector& aux_states, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); // internal initialization of the graph for simple bind Graph InitGraph(nnvm::Symbol symbol, const Context& default_ctx,