From 491700e97b5737176ffd2d33c0552209586f61d7 Mon Sep 17 00:00:00 2001
From: reminisce
Date: Wed, 3 May 2017 20:54:56 -0700
Subject: [PATCH] Refactor

---
 src/executor/graph_executor.cc | 194 ++++++++++++++++-----------------
 1 file changed, 97 insertions(+), 97 deletions(-)

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7f3f50eaea15..b81b102f3299 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -175,10 +175,12 @@ inline ValueType get_node_attr(
   }
 }
 
-nnvm::Graph GraphExecutor::InitFullGraphV1(
-    nnvm::Symbol symbol,
-    const std::vector<OpReqType>& grad_req_type,
-    const std::vector<NDArray>& arg_grad_store) {
+/*!
+ * \brief Create the graph for backward pass.
+ * This is triggered by both simple_bind and bind flows.
+ */
+nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
+                                         const std::vector<OpReqType>& grad_req_types) {
   using nnvm::NodePtr;
   using nnvm::NodeEntry;
   // initial information
@@ -188,7 +190,7 @@ nnvm::Graph GraphExecutor::InitFullGraphV1(
   nnvm::Graph g;
   g.outputs = symbol.outputs;
   bool need_grad = false;
-  for (OpReqType req : grad_req_type) {
+  for (OpReqType req : grad_req_types) {
     if (req != kNullOp) need_grad = true;
   }
   if (!need_grad) return g;
@@ -199,10 +201,8 @@ nnvm::Graph GraphExecutor::InitFullGraphV1(
   }
   std::vector<NodePtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
   std::vector<NodeEntry> xs;
-  for (size_t i = 0; i < grad_req_type.size(); ++i) {
-    if (grad_req_type[i] != kNullOp) {
-      grad_store_.emplace_back(
-          std::make_pair(grad_req_type[i], arg_grad_store[i]));
+  for (size_t i = 0; i < grad_req_types.size(); ++i) {
+    if (grad_req_types[i] != kNullOp) {
       xs.emplace_back(NodeEntry{args[i], 0, 0});
     }
   }
@@ -238,12 +238,10 @@ nnvm::Graph GraphExecutor::InitFullGraphV1(
   return g;
 }
 
-/*!
- * \brief Create the graph for backward pass.
- * This is triggered by both simple_bind and bind flows.
- */
-nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
-                                         const std::vector<OpReqType>& grad_req_types) {
+nnvm::Graph GraphExecutor::InitFullGraphV1(
+    nnvm::Symbol symbol,
+    const std::vector<OpReqType>& grad_req_type,
+    const std::vector<NDArray>& arg_grad_store) {
   using nnvm::NodePtr;
   using nnvm::NodeEntry;
   // initial information
@@ -253,7 +251,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
   nnvm::Graph g;
   g.outputs = symbol.outputs;
   bool need_grad = false;
-  for (OpReqType req : grad_req_types) {
+  for (OpReqType req : grad_req_type) {
     if (req != kNullOp) need_grad = true;
   }
   if (!need_grad) return g;
@@ -264,8 +262,10 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
   }
   std::vector<NodePtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
   std::vector<NodeEntry> xs;
-  for (size_t i = 0; i < grad_req_types.size(); ++i) {
-    if (grad_req_types[i] != kNullOp) {
+  for (size_t i = 0; i < grad_req_type.size(); ++i) {
+    if (grad_req_type[i] != kNullOp) {
+      grad_store_.emplace_back(
+          std::make_pair(grad_req_type[i], arg_grad_store[i]));
       xs.emplace_back(NodeEntry{args[i], 0, 0});
     }
   }
@@ -301,71 +301,83 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
   return g;
 }
 
-// pass to assign context to the graph
-Graph AssignContextV1(Graph g,
-                      const Context& default_ctx,
-                      const std::map<std::string, Context>& ctx_map,
-                      const std::vector<NDArray>& in_args,
-                      const std::vector<std::pair<OpReqType, NDArray> >& grad_store,
-                      const std::vector<NDArray>& aux_states,
-                      size_t num_forward_inputs,
-                      size_t num_forward_outputs) {
+/*!
+ * \brief Assign context to the graph.
+ * This is triggered by both simple_bind and bind flows.
+ */ +Graph AssignContext(Graph g, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + size_t num_forward_inputs, + size_t num_forward_outputs) { const auto& idx = g.indexed_graph(); const auto& mutable_nodes = idx.mutable_input_nodes(); // default use default context. if (ctx_map.size() == 0) { g.attrs["context"] = std::make_shared( ContextVector(idx.num_nodes(), default_ctx)); - for (const auto& x : in_args) { - CHECK(x.ctx() == default_ctx) - << "Input array is in " << x.ctx() << " while binding with ctx=" << default_ctx + for (const auto& x : in_arg_ctxes) { + CHECK(x == default_ctx) + << "Input array is in " << x << " while binding with ctx=" << default_ctx << ". All arguments must be in global context (" << default_ctx << ") unless group2ctx is specified for cross-device graph."; } - for (const auto& x : grad_store) { - CHECK(x.second.ctx() == default_ctx) - << "Gradient array is in " << x.second.ctx() << " while binding with ctx=" + for (const auto& x : arg_grad_ctxes) { + CHECK(x == default_ctx) + << "Gradient array is in " << x << " while binding with ctx=" << default_ctx << ". All gradients must be in global context (" << default_ctx << ") unless group2ctx is specified for cross-device graph."; } return g; } + // otherwise, use context assignment. - std::map ctx2id; - std::vector ctx_list; - nnvm::DeviceVector device(idx.num_nodes(), -1); - nnvm::DeviceAssignMap device_map; + std::map ctx2id; // map ctx to device id + std::vector ctx_list; // index is device id + nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id + nnvm::DeviceAssignMap device_map; // map arg name to device id + // loop through the user input ctx_map and + // populate maps and lists for (auto &kv : ctx_map) { - if (ctx2id.count(kv.second) == 0) { - ctx2id[kv.second] = static_cast(ctx_list.size()); - ctx_list.push_back(kv.second); + if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one + ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx + ctx_list.push_back(kv.second); // save ctx to the list } + // assign device id to to the arg name with the corresponding ctx device_map[kv.first] = ctx2id.at(kv.second); } + // loop through all the rest of input nodes not specified + // in the ctx_map and populate maps and lists size_t arg_top = 0, aux_top = 0; for (size_t i = 0; i < num_forward_inputs; ++i) { const uint32_t nid = idx.input_nodes().at(i); Context ctx; - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_states.size()); - ctx = aux_states[aux_top].ctx(); + if (mutable_nodes.count(nid)) { // aux node is mutable + CHECK_LT(aux_top, aux_state_ctxes.size()); + ctx = aux_state_ctxes[aux_top]; ++aux_top; - } else { - CHECK_LT(arg_top, in_args.size()); - ctx = in_args[arg_top].ctx(); + } else { // regular input node is immutable + CHECK_LT(arg_top, in_arg_ctxes.size()); + ctx = in_arg_ctxes[arg_top]; ++arg_top; } - if (ctx2id.count(ctx) == 0) { - ctx2id[ctx] = static_cast(ctx_list.size()); - ctx_list.push_back(ctx); + if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id + ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id + ctx_list.push_back(ctx); // save the current ctx in the list } - device[nid] = ctx2id.at(ctx); + device[nid] = ctx2id.at(ctx); // assign device id to the current node } + + // loop through backward input nodes and populate maps and lists + // 
   for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
     const uint32_t nid = idx.outputs()[i].node_id;
-    Context ctx = grad_store[i - num_forward_outputs].second.ctx();
+    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
     if (ctx2id.count(ctx) == 0) {
       ctx2id[ctx] = static_cast<int>(ctx_list.size());
       ctx_list.push_back(ctx);
@@ -377,6 +389,7 @@ Graph AssignContextV1(Graph g,
       device[nid] = devid;
     }
   }
+
   g.attrs["device"] = std::make_shared<nnvm::any>(std::move(device));
   g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
   const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");
@@ -393,83 +406,71 @@ Graph AssignContextV1(Graph g,
   return g;
 }
 
-/*!
- * \brief Assign context to the graph.
- * This is triggered by both simple_bind and bind flows.
- */
-Graph AssignContext(Graph g,
-                    const Context& default_ctx,
-                    const std::map<std::string, Context>& ctx_map,
-                    const std::vector<Context>& in_arg_ctxes,
-                    const std::vector<Context>& arg_grad_ctxes,
-                    const std::vector<Context>& aux_state_ctxes,
-                    size_t num_forward_inputs,
-                    size_t num_forward_outputs) {
+// pass to assign context to the graph
+Graph AssignContextV1(Graph g,
+                      const Context& default_ctx,
+                      const std::map<std::string, Context>& ctx_map,
+                      const std::vector<NDArray>& in_args,
+                      const std::vector<std::pair<OpReqType, NDArray> >& grad_store,
+                      const std::vector<NDArray>& aux_states,
+                      size_t num_forward_inputs,
+                      size_t num_forward_outputs) {
   const auto& idx = g.indexed_graph();
   const auto& mutable_nodes = idx.mutable_input_nodes();
   // default use default context.
   if (ctx_map.size() == 0) {
     g.attrs["context"] = std::make_shared<nnvm::any>(
         ContextVector(idx.num_nodes(), default_ctx));
-    for (const auto& x : in_arg_ctxes) {
-      CHECK(x == default_ctx)
-        << "Input array is in " << x << " while binding with ctx=" << default_ctx
+    for (const auto& x : in_args) {
+      CHECK(x.ctx() == default_ctx)
+        << "Input array is in " << x.ctx() << " while binding with ctx=" << default_ctx
         << ". All arguments must be in global context (" << default_ctx
         << ") unless group2ctx is specified for cross-device graph.";
     }
-    for (const auto& x : arg_grad_ctxes) {
-      CHECK(x == default_ctx)
-        << "Gradient array is in " << x << " while binding with ctx="
+    for (const auto& x : grad_store) {
+      CHECK(x.second.ctx() == default_ctx)
+        << "Gradient array is in " << x.second.ctx() << " while binding with ctx="
         << default_ctx << ". All gradients must be in global context (" << default_ctx
         << ") unless group2ctx is specified for cross-device graph.";
     }
     return g;
   }
-
   // otherwise, use context assignment.
-  std::map<Context, int> ctx2id;  // map ctx to device id
-  std::vector<Context> ctx_list;  // index is device id
-  nnvm::DeviceVector device(idx.num_nodes(), -1);  // index is node id
-  nnvm::DeviceAssignMap device_map;  // map arg name to device id
+  std::map<Context, int> ctx2id;
+  std::vector<Context> ctx_list;
+  nnvm::DeviceVector device(idx.num_nodes(), -1);
+  nnvm::DeviceAssignMap device_map;
 
-  // loop through the user input ctx_map and
-  // populate maps and lists
   for (auto &kv : ctx_map) {
-    if (ctx2id.count(kv.second) == 0) {  // if context has no device id, create one
-      ctx2id[kv.second] = static_cast<int>(ctx_list.size());  // assign device id to ctx
-      ctx_list.push_back(kv.second);  // save ctx to the list
+    if (ctx2id.count(kv.second) == 0) {
+      ctx2id[kv.second] = static_cast<int>(ctx_list.size());
+      ctx_list.push_back(kv.second);
    }
-    // assign device id to the arg name with the corresponding ctx
     device_map[kv.first] = ctx2id.at(kv.second);
   }
 
-  // loop through all the rest of input nodes not specified
-  // in the ctx_map and populate maps and lists
   size_t arg_top = 0, aux_top = 0;
   for (size_t i = 0; i < num_forward_inputs; ++i) {
     const uint32_t nid = idx.input_nodes().at(i);
     Context ctx;
-    if (mutable_nodes.count(nid)) {  // aux node is mutable
-      CHECK_LT(aux_top, aux_state_ctxes.size());
-      ctx = aux_state_ctxes[aux_top];
+    if (mutable_nodes.count(nid)) {
+      CHECK_LT(aux_top, aux_states.size());
+      ctx = aux_states[aux_top].ctx();
       ++aux_top;
-    } else {  // regular input node is immutable
-      CHECK_LT(arg_top, in_arg_ctxes.size());
-      ctx = in_arg_ctxes[arg_top];
+    } else {
+      CHECK_LT(arg_top, in_args.size());
+      ctx = in_args[arg_top].ctx();
       ++arg_top;
     }
-    if (ctx2id.count(ctx) == 0) {  // if the current ctx is not in the map of ctx and device id
-      ctx2id[ctx] = static_cast<int>(ctx_list.size());  // assign the current ctx with device id
-      ctx_list.push_back(ctx);  // save the current ctx in the list
+    if (ctx2id.count(ctx) == 0) {
+      ctx2id[ctx] = static_cast<int>(ctx_list.size());
+      ctx_list.push_back(ctx);
     }
-    device[nid] = ctx2id.at(ctx);  // assign device id to the current node
+    device[nid] = ctx2id.at(ctx);
   }
-
-  // loop through backward input nodes and populate maps and lists
-  // the backward input nodes are the gradients of the loss wrt the outputs
   for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
     const uint32_t nid = idx.outputs()[i].node_id;
-    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+    Context ctx = grad_store[i - num_forward_outputs].second.ctx();
     if (ctx2id.count(ctx) == 0) {
       ctx2id[ctx] = static_cast<int>(ctx_list.size());
       ctx_list.push_back(ctx);
@@ -481,7 +482,6 @@ Graph AssignContext(Graph g,
       device[nid] = devid;
     }
   }
-
   g.attrs["device"] = std::make_shared<nnvm::any>(std::move(device));
   g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy");
   const auto& assigned_device = g.GetAttr<nnvm::DeviceVector>("device");