
Commit: Refactor

reminisce committed May 4, 2017
1 parent f9d2a9f commit 491700e
Showing 1 changed file with 97 additions and 97 deletions.
src/executor/graph_executor.cc: 97 additions & 97 deletions
@@ -175,10 +175,12 @@ inline ValueType get_node_attr(
}
}

nnvm::Graph GraphExecutor::InitFullGraphV1(
nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& arg_grad_store) {
/*!
* \brief Create the graph for backward pass.
* This is triggered by both simple_bind and bind flows.
*/
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types) {
using nnvm::NodePtr;
using nnvm::NodeEntry;
// initial information
@@ -188,7 +190,7 @@ nnvm::Graph GraphExecutor::InitFullGraphV1(
nnvm::Graph g;
g.outputs = symbol.outputs;
bool need_grad = false;
for (OpReqType req : grad_req_type) {
for (OpReqType req : grad_req_types) {
if (req != kNullOp) need_grad = true;
}
if (!need_grad) return g;
@@ -199,10 +201,8 @@ nnvm::Graph GraphExecutor::InitFullGraphV1(
}
std::vector<NodePtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
std::vector<NodeEntry> xs;
for (size_t i = 0; i < grad_req_type.size(); ++i) {
if (grad_req_type[i] != kNullOp) {
grad_store_.emplace_back(
std::make_pair(grad_req_type[i], arg_grad_store[i]));
for (size_t i = 0; i < grad_req_types.size(); ++i) {
if (grad_req_types[i] != kNullOp) {
xs.emplace_back(NodeEntry{args[i], 0, 0});
}
}
@@ -238,12 +238,10 @@ nnvm::Graph GraphExecutor::InitFullGraphV1(
return g;
}

/*!
* \brief Create the graph for backward pass.
* This is triggered by both simple_bind and bind flows.
*/
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_types) {
nnvm::Graph GraphExecutor::InitFullGraphV1(
nnvm::Symbol symbol,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& arg_grad_store) {
using nnvm::NodePtr;
using nnvm::NodeEntry;
// initial information
@@ -253,7 +251,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
nnvm::Graph g;
g.outputs = symbol.outputs;
bool need_grad = false;
for (OpReqType req : grad_req_types) {
for (OpReqType req : grad_req_type) {
if (req != kNullOp) need_grad = true;
}
if (!need_grad) return g;
@@ -264,8 +262,10 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
}
std::vector<NodePtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
std::vector<NodeEntry> xs;
for (size_t i = 0; i < grad_req_types.size(); ++i) {
if (grad_req_types[i] != kNullOp) {
for (size_t i = 0; i < grad_req_type.size(); ++i) {
if (grad_req_type[i] != kNullOp) {
grad_store_.emplace_back(
std::make_pair(grad_req_type[i], arg_grad_store[i]));
xs.emplace_back(NodeEntry{args[i], 0, 0});
}
}
@@ -301,71 +301,83 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
return g;
}

// pass to assign context to the graph
Graph AssignContextV1(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<std::pair<OpReqType, NDArray> >& grad_store,
const std::vector<NDArray>& aux_states,
size_t num_forward_inputs,
size_t num_forward_outputs) {
/*!
* \brief Assign context to the graph.
* This is triggered by both simple_bind and bind flows.
*/
Graph AssignContext(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
size_t num_forward_inputs,
size_t num_forward_outputs) {
const auto& idx = g.indexed_graph();
const auto& mutable_nodes = idx.mutable_input_nodes();
// by default, use the default context.
if (ctx_map.size() == 0) {
g.attrs["context"] = std::make_shared<nnvm::any>(
ContextVector(idx.num_nodes(), default_ctx));
for (const auto& x : in_args) {
CHECK(x.ctx() == default_ctx)
<< "Input array is in " << x.ctx() << " while binding with ctx=" << default_ctx
for (const auto& x : in_arg_ctxes) {
CHECK(x == default_ctx)
<< "Input array is in " << x << " while binding with ctx=" << default_ctx
<< ". All arguments must be in global context (" << default_ctx
<< ") unless group2ctx is specified for cross-device graph.";
}
for (const auto& x : grad_store) {
CHECK(x.second.ctx() == default_ctx)
<< "Gradient array is in " << x.second.ctx() << " while binding with ctx="
for (const auto& x : arg_grad_ctxes) {
CHECK(x == default_ctx)
<< "Gradient array is in " << x << " while binding with ctx="
<< default_ctx << ". All gradients must be in global context (" << default_ctx
<< ") unless group2ctx is specified for cross-device graph.";
}
return g;
}

// otherwise, use context assignment.
std::map<Context, int> ctx2id;
std::vector<Context> ctx_list;
nnvm::DeviceVector device(idx.num_nodes(), -1);
nnvm::DeviceAssignMap device_map;
std::map<Context, int> ctx2id; // map ctx to device id
std::vector<Context> ctx_list; // index is device id
nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id
nnvm::DeviceAssignMap device_map; // map arg name to device id

// loop through the user input ctx_map and
// populate maps and lists
for (auto &kv : ctx_map) {
if (ctx2id.count(kv.second) == 0) {
ctx2id[kv.second] = static_cast<int>(ctx_list.size());
ctx_list.push_back(kv.second);
if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one
ctx2id[kv.second] = static_cast<int>(ctx_list.size()); // assign device id to ctx
ctx_list.push_back(kv.second); // save ctx to the list
}
// assign the device id to the arg name with the corresponding ctx
device_map[kv.first] = ctx2id.at(kv.second);
}

// loop through all the rest of input nodes not specified
// in the ctx_map and populate maps and lists
size_t arg_top = 0, aux_top = 0;
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
Context ctx;
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_states.size());
ctx = aux_states[aux_top].ctx();
if (mutable_nodes.count(nid)) { // aux node is mutable
CHECK_LT(aux_top, aux_state_ctxes.size());
ctx = aux_state_ctxes[aux_top];
++aux_top;
} else {
CHECK_LT(arg_top, in_args.size());
ctx = in_args[arg_top].ctx();
} else { // regular input node is immutable
CHECK_LT(arg_top, in_arg_ctxes.size());
ctx = in_arg_ctxes[arg_top];
++arg_top;
}
if (ctx2id.count(ctx) == 0) {
ctx2id[ctx] = static_cast<int>(ctx_list.size());
ctx_list.push_back(ctx);
if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id
ctx2id[ctx] = static_cast<int>(ctx_list.size()); // assign the current ctx with device id
ctx_list.push_back(ctx); // save the current ctx in the list
}
device[nid] = ctx2id.at(ctx);
device[nid] = ctx2id.at(ctx); // assign device id to the current node
}

// loop through backward input nodes and populate maps and lists
// the backward input nodes are the gradients of the loss wrt the outputs
for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
const uint32_t nid = idx.outputs()[i].node_id;
Context ctx = grad_store[i - num_forward_outputs].second.ctx();
Context ctx = arg_grad_ctxes[i - num_forward_outputs];
if (ctx2id.count(ctx) == 0) {
ctx2id[ctx] = static_cast<int>(ctx_list.size());
ctx_list.push_back(ctx);
@@ -377,6 +389,7 @@ Graph AssignContextV1(Graph g,
device[nid] = devid;
}
}

g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");
@@ -393,83 +406,71 @@ Graph AssignContextV1(Graph g,
return g;
}

/*!
* \brief Assign context to the graph.
* This is triggered by both simple_bind and bind flows.
*/
Graph AssignContext(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
size_t num_forward_inputs,
size_t num_forward_outputs) {
// pass to assign context to the graph
Graph AssignContextV1(Graph g,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<std::pair<OpReqType, NDArray> >& grad_store,
const std::vector<NDArray>& aux_states,
size_t num_forward_inputs,
size_t num_forward_outputs) {
const auto& idx = g.indexed_graph();
const auto& mutable_nodes = idx.mutable_input_nodes();
// by default, use the default context.
if (ctx_map.size() == 0) {
g.attrs["context"] = std::make_shared<nnvm::any>(
ContextVector(idx.num_nodes(), default_ctx));
for (const auto& x : in_arg_ctxes) {
CHECK(x == default_ctx)
<< "Input array is in " << x << " while binding with ctx=" << default_ctx
for (const auto& x : in_args) {
CHECK(x.ctx() == default_ctx)
<< "Input array is in " << x.ctx() << " while binding with ctx=" << default_ctx
<< ". All arguments must be in global context (" << default_ctx
<< ") unless group2ctx is specified for cross-device graph.";
}
for (const auto& x : arg_grad_ctxes) {
CHECK(x == default_ctx)
<< "Gradient array is in " << x << " while binding with ctx="
for (const auto& x : grad_store) {
CHECK(x.second.ctx() == default_ctx)
<< "Gradient array is in " << x.second.ctx() << " while binding with ctx="
<< default_ctx << ". All gradients must be in global context (" << default_ctx
<< ") unless group2ctx is specified for cross-device graph.";
}
return g;
}

// otherwise, use context assignment.
std::map<Context, int> ctx2id; // map ctx to device id
std::vector<Context> ctx_list; // index is device id
nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id
nnvm::DeviceAssignMap device_map; // map arg name to device id
std::map<Context, int> ctx2id;
std::vector<Context> ctx_list;
nnvm::DeviceVector device(idx.num_nodes(), -1);
nnvm::DeviceAssignMap device_map;

// loop through the user input ctx_map and
// populate maps and lists
for (auto &kv : ctx_map) {
if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one
ctx2id[kv.second] = static_cast<int>(ctx_list.size()); // assign device id to ctx
ctx_list.push_back(kv.second); // save ctx to the list
if (ctx2id.count(kv.second) == 0) {
ctx2id[kv.second] = static_cast<int>(ctx_list.size());
ctx_list.push_back(kv.second);
}
// assign the device id to the arg name with the corresponding ctx
device_map[kv.first] = ctx2id.at(kv.second);
}

// loop through all the rest of input nodes not specified
// in the ctx_map and populate maps and lists
size_t arg_top = 0, aux_top = 0;
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
Context ctx;
if (mutable_nodes.count(nid)) { // aux node is mutable
CHECK_LT(aux_top, aux_state_ctxes.size());
ctx = aux_state_ctxes[aux_top];
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_states.size());
ctx = aux_states[aux_top].ctx();
++aux_top;
} else { // regular input node is immutable
CHECK_LT(arg_top, in_arg_ctxes.size());
ctx = in_arg_ctxes[arg_top];
} else {
CHECK_LT(arg_top, in_args.size());
ctx = in_args[arg_top].ctx();
++arg_top;
}
if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id
ctx2id[ctx] = static_cast<int>(ctx_list.size()); // assign the current ctx with device id
ctx_list.push_back(ctx); // save the current ctx in the list
if (ctx2id.count(ctx) == 0) {
ctx2id[ctx] = static_cast<int>(ctx_list.size());
ctx_list.push_back(ctx);
}
device[nid] = ctx2id.at(ctx); // assign device id to the current node
device[nid] = ctx2id.at(ctx);
}

// loop through backward input nodes and populate maps and lists
// the backward input nodes are the gradients of the loss wrt the outputs
for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
const uint32_t nid = idx.outputs()[i].node_id;
Context ctx = arg_grad_ctxes[i - num_forward_outputs];
Context ctx = grad_store[i - num_forward_outputs].second.ctx();
if (ctx2id.count(ctx) == 0) {
ctx2id[ctx] = static_cast<int>(ctx_list.size());
ctx_list.push_back(ctx);
@@ -481,7 +482,6 @@ Graph AssignContext(Graph g,
device[nid] = devid;
}
}

g.attrs["device"] = std::make_shared<dmlc::any>(std::move(device));
g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");

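For readers skimming the diff: the substantive difference between the V1 functions and their replacements is the interface. InitFullGraphV1 and AssignContextV1 take NDArrays (in_args, grad_store, aux_states) and read the contexts off them, while the new InitFullGraph and AssignContext take the contexts directly (in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes). The device-assignment bookkeeping itself is unchanged: each distinct context gets a small integer device id via ctx2id/ctx_list, every node is tagged with that id in the device vector, and the result is handed to nnvm::pass::PlaceDevice. The snippet below is a minimal standalone sketch of that mapping pattern only, not MXNet code; std::string stands in for mxnet::Context and the context names are invented for illustration.

// Standalone sketch of the ctx -> device-id bookkeeping used by AssignContext.
// std::string replaces mxnet::Context; the contexts below are made up.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical per-argument contexts, analogous to in_arg_ctxes.
  std::vector<std::string> in_arg_ctxes = {"cpu(0)", "gpu(0)", "gpu(0)", "gpu(1)"};

  std::map<std::string, int> ctx2id;                  // map ctx to device id
  std::vector<std::string> ctx_list;                  // index is device id
  std::vector<int> device(in_arg_ctxes.size(), -1);   // index is node id

  for (size_t nid = 0; nid < in_arg_ctxes.size(); ++nid) {
    const std::string& ctx = in_arg_ctxes[nid];
    if (ctx2id.count(ctx) == 0) {                       // first time this ctx is seen
      ctx2id[ctx] = static_cast<int>(ctx_list.size());  // assign it a new device id
      ctx_list.push_back(ctx);                          // remember it
    }
    device[nid] = ctx2id.at(ctx);                       // tag the node with the id
  }

  for (size_t nid = 0; nid < device.size(); ++nid) {
    std::cout << "node " << nid << " -> device id " << device[nid]
              << " (" << ctx_list[device[nid]] << ")\n";
  }
  return 0;
}

Compiled with any C++11 compiler, this prints one device id per node, with identical contexts sharing an id; the real pass builds the device vector the same way and then hands it, together with device_map, to nnvm::pass::PlaceDevice.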