Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 4, 2017
1 parent 491700e commit 8bf8f28
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
20 changes: 15 additions & 5 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_types,
const std::vector<NDArray>& aux_states,
Executor* shared_exec) {
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
// create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes
auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); };
auto get_ctx2 = [default_ctx](const NDArray& nd) {
Expand Down Expand Up @@ -564,7 +565,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
// Initialize the rest attributes of the graph.
// This function can be called by regular bind
// operation flow as well.
FinishInitGraph(symbol, g, shared_exec);
FinishInitGraph(symbol, g, shared_exec, feed_dict);
}

void GraphExecutor::InitV1(nnvm::Symbol symbol,
Expand Down Expand Up @@ -782,7 +783,10 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
* \brief Finish graph initialization after shape and dtype inferences.
* This function is used by both simple_bind and bind flows.
*/
void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor* shared_exec) {
void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
nnvm::Graph g,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
const auto& idx = g.indexed_graph();
for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second;
Expand All @@ -796,6 +800,11 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor
for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID;
}
for (const auto& kv : feed_dict) {
uint32_t eid = idx.entry_id(kv.first);
data_entry_[eid] = kv.second;
arg_storage_id[eid] = kExternalStorageID;
}
g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(arg_storage_id));
g = nnvm::ApplyPass(g, "PlanMemory");
}
Expand Down Expand Up @@ -862,7 +871,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec,
std::unordered_map<std::string, NDArray>* shared_data_arrays,
Executor* shared_exec) {
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
aux_state_ctxes, grad_req_types);

Expand Down Expand Up @@ -912,7 +922,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
// Initialize the rest attributes of the graph.
// This function can be called by regular bind
// operation flow as well.
FinishInitGraph(symbol, g, shared_exec);
FinishInitGraph(symbol, g, shared_exec, feed_dict);
}

Graph GraphExecutor::InitGraphV1(nnvm::Symbol symbol,
Expand Down
21 changes: 16 additions & 5 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ class GraphExecutor : public Executor {
void SetMonitorCallback(const MonitorCallback& callback) override;
// Initialize the rest of attributes
// after setting up arguments.
void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor* shared_exec = nullptr);
void FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g,
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());

// initialized the executor
void InitV1(nnvm::Symbol symbol,
Expand All @@ -63,7 +66,9 @@ class GraphExecutor : public Executor {
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
Executor* shared_exec = nullptr);
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
// initialize executor for bind
void Init(nnvm::Symbol symbol,
const Context& default_ctx,
Expand All @@ -72,7 +77,9 @@ class GraphExecutor : public Executor {
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_types,
const std::vector<NDArray>& aux_states,
Executor* shared_exec = nullptr);
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
// initialize executor for simple bind
void Init(nnvm::Symbol symbol,
const Context& default_ctx,
Expand All @@ -91,7 +98,9 @@ class GraphExecutor : public Executor {
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec,
std::unordered_map<std::string, NDArray>* shared_data_arrays = nullptr,
Executor* shared_exec = nullptr);
Executor* shared_exec = nullptr,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());

protected:
// Information about operational node
Expand Down Expand Up @@ -160,7 +169,9 @@ class GraphExecutor : public Executor {
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states);
const std::vector<NDArray>& aux_states,
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
// internal initialization of the graph for simple bind
Graph InitGraph(nnvm::Symbol symbol,
const Context& default_ctx,
Expand Down

0 comments on commit 8bf8f28

Please sign in to comment.