Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Rename APIs in graph executor
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 4, 2017
1 parent f67d45f commit f9d2a9f
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 155 deletions.
200 changes: 100 additions & 100 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ inline ValueType get_node_attr(
}
}

nnvm::Graph GraphExecutor::InitFullGraph(
nnvm::Graph GraphExecutor::InitFullGraphV1(
nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& arg_grad_store) {
Expand Down Expand Up @@ -242,8 +242,8 @@ nnvm::Graph GraphExecutor::InitFullGraph(
* \brief Create the graph for backward pass.
* This is triggered by both simple_bind and bind flows.
*/
nnvm::Graph GraphExecutor::InitFullGraph2(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types) {
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types) {
using nnvm::NodePtr;
using nnvm::NodeEntry;
// initial information
Expand Down Expand Up @@ -302,14 +302,14 @@ nnvm::Graph GraphExecutor::InitFullGraph2(nnvm::Symbol symbol,
}

// pass to assign context to the graph
Graph AssignContext(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<std::pair<OpReqType, NDArray> >& grad_store,
const std::vector<NDArray>& aux_states,
size_t num_forward_inputs,
size_t num_forward_outputs) {
Graph AssignContextV1(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<std::pair<OpReqType, NDArray> >& grad_store,
const std::vector<NDArray>& aux_states,
size_t num_forward_inputs,
size_t num_forward_outputs) {
const auto& idx = g.indexed_graph();
const auto& mutable_nodes = idx.mutable_input_nodes();
// default use default context.
Expand Down Expand Up @@ -397,14 +397,14 @@ Graph AssignContext(Graph g,
* \brief Assign context to the graph.
* This is triggered by both simple_bind and bind flows.
*/
Graph AssignContext2(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
size_t num_forward_inputs,
size_t num_forward_outputs) {
Graph AssignContext(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
size_t num_forward_inputs,
size_t num_forward_outputs) {
const auto& idx = g.indexed_graph();
const auto& mutable_nodes = idx.mutable_input_nodes();
// default use default context.
Expand Down Expand Up @@ -503,14 +503,14 @@ Graph AssignContext2(Graph g,
* input arguments and gradients are provided by users. This initializer
* uses the user provided NDArrays to populate data entries of the graph.
*/
void GraphExecutor::Init2(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_types,
const std::vector<NDArray>& aux_states,
Executor* shared_exec) {
void GraphExecutor::Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_types,
const std::vector<NDArray>& aux_states,
Executor* shared_exec) {
// create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes
auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); };
auto get_ctx2 = [default_ctx](const NDArray& nd) {
Expand All @@ -524,8 +524,8 @@ void GraphExecutor::Init2(nnvm::Symbol symbol,
std::vector<Context> aux_state_ctxes(aux_states.size());
std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1);

nnvm::Graph g = InitGraph2(symbol, default_ctx, ctx_map, in_arg_ctxes,
arg_grad_ctxes, aux_state_ctxes, grad_req_types);
nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
arg_grad_ctxes, aux_state_ctxes, grad_req_types);

// create arg_shapes and arg_dtypes for shape and type inferences
const auto& idx = g.indexed_graph();
Expand Down Expand Up @@ -567,18 +567,18 @@ void GraphExecutor::Init2(nnvm::Symbol symbol,
FinishInitGraph(symbol, g, shared_exec);
}

void GraphExecutor::Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
nnvm::Graph g = InitGraph(symbol, default_ctx,
ctx_map, in_args, arg_grad_store,
grad_req_type, aux_states, feed_dict);
void GraphExecutor::InitV1(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
nnvm::Graph g = InitGraphV1(symbol, default_ctx,
ctx_map, in_args, arg_grad_store,
grad_req_type, aux_states, feed_dict);
g.attrs["saved_opr"] = std::make_shared<nnvm::any>(std::move(saved_opr_));
g = AttachOpExecs(g);
g = AttachOpResources(g);
Expand Down Expand Up @@ -845,26 +845,26 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, nnvm::Graph g, Executor
* NDArrays for in_args, arg_grads, and aux_states for resuing
* already allocated memory.
*/
void GraphExecutor::Init2(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
const std::vector<NDArray>& shared_exec_in_args,
const std::vector<NDArray>& shared_exec_arg_grads,
const std::vector<NDArray>& shared_exec_aux_states,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec,
std::unordered_map<std::string, NDArray>* shared_data_arrays,
Executor* shared_exec) {
nnvm::Graph g = InitGraph2(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
aux_state_ctxes, grad_req_types);
void GraphExecutor::Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
const std::vector<NDArray>& shared_exec_in_args,
const std::vector<NDArray>& shared_exec_arg_grads,
const std::vector<NDArray>& shared_exec_aux_states,
std::vector<NDArray>* in_arg_vec,
std::vector<NDArray>* arg_grad_vec,
std::vector<NDArray>* aux_state_vec,
std::unordered_map<std::string, NDArray>* shared_data_arrays,
Executor* shared_exec) {
nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
aux_state_ctxes, grad_req_types);

// The following code of shape and dtype inferences and argument
// initialization is for simple_bind only. Regular bind operation
Expand Down Expand Up @@ -915,22 +915,22 @@ void GraphExecutor::Init2(nnvm::Symbol symbol,
FinishInitGraph(symbol, g, shared_exec);
}

Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
Graph GraphExecutor::InitGraphV1(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
// setup gradient
nnvm::Graph g = InitFullGraph(symbol, grad_req_type, arg_grad_store);
g = AssignContext(g, default_ctx, ctx_map,
in_args,
grad_store_,
aux_states,
num_forward_inputs_,
num_forward_outputs_);
nnvm::Graph g = InitFullGraphV1(symbol, grad_req_type, arg_grad_store);
g = AssignContextV1(g, default_ctx, ctx_map,
in_args,
grad_store_,
aux_states,
num_forward_inputs_,
num_forward_outputs_);
const auto& idx = g.indexed_graph();
// get number of nodes used in forward pass
num_forward_nodes_ = 0;
Expand Down Expand Up @@ -997,23 +997,23 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
* attributes in the graph, and calculate the number
* of forward nodes.
*/
Graph GraphExecutor::InitGraph2(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types) {
Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::vector<OpReqType>& grad_req_types) {
// setup gradient
nnvm::Graph g = InitFullGraph2(symbol, grad_req_types);
nnvm::Graph g = InitFullGraph(symbol, grad_req_types);

// create "device" and "context" attrs for the graph
g = AssignContext2(g, default_ctx, ctx_map,
in_arg_ctxes,
arg_grad_ctxes,
aux_state_ctxes,
num_forward_inputs_,
num_forward_outputs_);
g = AssignContext(g, default_ctx, ctx_map,
in_arg_ctxes,
arg_grad_ctxes,
aux_state_ctxes,
num_forward_inputs_,
num_forward_outputs_);

const auto& idx = g.indexed_graph();
// get number of nodes used in forward pass
Expand Down Expand Up @@ -1508,13 +1508,13 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray>* shared_data_arrays,
Executor* shared_exec) {
auto exec = new exec::GraphExecutor();
exec->Init2(symbol, default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map,
grad_req_types, param_names, shared_exec_in_args,
shared_exec_arg_grads, shared_exec_aux_states,
in_args, arg_grads, aux_states,
shared_data_arrays, shared_exec);
exec->Init(symbol, default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map,
grad_req_types, param_names, shared_exec_in_args,
shared_exec_arg_grads, shared_exec_aux_states,
in_args, arg_grads, aux_states,
shared_data_arrays, shared_exec);
return exec;
}

Expand All @@ -1527,9 +1527,9 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
const std::vector<NDArray> &aux_states,
Executor* shared_exec) {
auto exec = new exec::GraphExecutor();
exec->Init2(symbol, default_ctx, group2ctx,
in_args, arg_grad_store, grad_req_type, aux_states,
reinterpret_cast<Executor*>(shared_exec));
exec->Init(symbol, default_ctx, group2ctx,
in_args, arg_grad_store, grad_req_type, aux_states,
reinterpret_cast<Executor*>(shared_exec));
return exec;
}
} // namespace mxnet
Loading

0 comments on commit f9d2a9f

Please sign in to comment.