From 0f77b3eb5453b67ab51f33d6ac71cfd00af0063a Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 21 Aug 2015 13:51:51 +0800 Subject: [PATCH 1/3] return value does not need std move http://stackoverflow.com/questions/11817873/using-stdmove-when-returning-a-value-from-a-function-to-avoid-to-copy --- include/mxnet/operator.h | 2 +- src/symbol/graph_executor.cc | 6 +++--- src/symbol/static_graph.cc | 4 ++-- src/symbol/symbol.cc | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index e60afe6948a7..b44a50d2db7d 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -353,7 +353,7 @@ class OperatorProperty { for (size_t i = 0; i < ret_index.size(); ++i) { ret[i] = all_data[ret_index[i]]; } - return std::move(ret); + return ret; } /*! * \brief create OperatorProperty diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index a434f22a2fc6..8dbadb34e24e 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -121,7 +121,7 @@ inline std::vector > GraphExecutor::GetInplaceOption( remap[i].first = in_data[rmap_index[i].first]; remap[i].second = *static_cast(rmap_index[i].second); } - return std::move(remap); + return remap; } else { CHECK(node.is_backward()); // forward property @@ -161,7 +161,7 @@ inline std::vector > GraphExecutor::GetInplaceOption( remap[i].first = *args_array[remap_index[i].first]; remap[i].second = *static_cast(remap_index[i].second); } - return std::move(remap); + return remap; } } @@ -196,7 +196,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { op_ctx_ptr->run_ctx = ctx; op->Forward(*op_ctx_ptr, in_data, req, out_data); }; - return std::move(exec); + return exec; } void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 5eb0ad14a282..c9ed278b8f7e 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ 
-46,7 +46,7 @@ std::vector StaticGraph::TopoSort() const { } } } - return std::move(ret); + return ret; } bool StaticGraph::InferNodeShapes(const std::vector &topo_order, @@ -161,7 +161,7 @@ StaticGraph::Node StaticGraph::CreateSumNode( os_size << grad_source.size(); agg_node.op->Init({{"size", os_size.str()}}); agg_node.inputs = grad_source; - return std::move(agg_node); + return agg_node; } void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 54a5fe9422b2..56ad0a869540 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -201,7 +201,7 @@ std::vector Symbol::ListReturns() const { } } } - return std::move(ret); + return ret; } Symbol Symbol::operator[] (size_t index) const { @@ -415,13 +415,13 @@ Symbol Symbol::CreateGroup(const std::vector &symbols) { for (const auto &s : symbols) { ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end()); } - return std::move(ret); + return ret; } Symbol Symbol::CreateVariable(const std::string &name) { Symbol s; s.heads_.push_back(DataEntry(std::make_shared(nullptr, name), 0)); - return std::move(s); + return s; } void Symbol::ToStaticGraph(StaticGraph *out_graph) const { From 32a04c3dab2362a69428f972ce2f9fdfa8b9ef4c Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 21 Aug 2015 17:32:23 +0800 Subject: [PATCH 2/3] support backward of backward --- include/mxnet/symbolic.h | 29 +++++++++++++++++++---------- src/symbol/graph_executor.cc | 10 +++++----- src/symbol/static_graph.cc | 21 +++++++++++++-------- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index df06c4913de8..f97a695b7f19 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -85,27 +85,36 @@ class StaticGraph { /*! \brief inputs (node_id, index) for of the nodes*/ std::vector inputs; /*! 
- * \brief If this field is nonnegative, this indicates this - * Node is corresponds to a Backward Operation of Operator. - * backward_source_id will points to the corresponding Forward Node. + * \brief source node id; if this field is negative, it means this + * Node is a forward node. If this field is nonnegative, it + * means this Node is the gradient of the source node. + */ + int32_t source_id; + /*! + * \brief backward; if this field is true, that means this node + * represents the backward function of the op. Else, it + * represents the forward function. When it represents the + * backward function, itself has not op but shares from the + * source node. It is because the backward function shares the + * states from the forward, and they need to share op. * - * For normal node, this field is -1. - * When the node is a Backward node, the op field will be nullptr + * Since we support gradient of gradient, a forward node can also + * be the gradient of another node. See source id. */ - int32_t backward_source_id; + bool backward; /*! \brief default constructor */ - Node() : backward_source_id(-1) {} + Node() : source_id(-1), backward(false) {} /*! \return whether the node is forward op node */ inline bool is_forward() const { - return op != nullptr; + return !backward && !is_variable(); } /*! \return whether the node is backward op node */ inline bool is_backward() const { - return backward_source_id != -1; + return backward; } /*! \return whether the node is variable node */ inline bool is_variable() const { - return op == nullptr && !is_backward(); + return op == nullptr && source_id == -1; } }; /*! 
\brief all nodes in the graph */ diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 8dbadb34e24e..d058242ef11e 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -80,7 +80,7 @@ GraphExecutor::GetResource(uint32_t node_id) const { return node.op->ForwardResource(); } else { CHECK(node.is_backward()); - return graph_.nodes[node.backward_source_id].op->BackwardResource(); + return graph_.nodes[node.source_id].op->BackwardResource(); } } @@ -90,7 +90,7 @@ inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const { return node.op->NumReturns(); } else if (node.is_backward()) { return static_cast( - graph_.nodes[node.backward_source_id].op->ListArguments().size()); + graph_.nodes[node.source_id].op->ListArguments().size()); } else { CHECK(node.is_variable()); return 1; @@ -125,7 +125,7 @@ inline std::vector > GraphExecutor::GetInplaceOption( } else { CHECK(node.is_backward()); // forward property - const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get(); + const OperatorProperty *fwd = graph_.nodes[node.source_id].op.get(); std::vector out_grad_index(fwd->NumVisibleReturns()); std::vector in_data_index(fwd->ListArguments().size()); @@ -406,8 +406,8 @@ void GraphExecutor::InitOpNodes() { } else { CHECK(graph_.nodes[nid].is_backward()); op_node.op.reset(new BackwardOpWrapper( - graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(), - op_nodes_[graph_.nodes[nid].backward_source_id].op)); + graph_.nodes[graph_.nodes[nid].source_id].op.get(), + op_nodes_[graph_.nodes[nid].source_id].op)); } bool allow_cache = true; for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index c9ed278b8f7e..82dc61193566 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -18,7 +18,7 @@ std::vector StaticGraph::TopoSort() const { ++out_degree[e.source_id]; } if (n.is_backward()) { - 
++out_degree[n.backward_source_id]; + ++out_degree[n.source_id]; } } std::vector ret(nodes.size()); @@ -41,8 +41,8 @@ std::vector StaticGraph::TopoSort() const { } } if (n.is_backward()) { - if (--out_degree[n.backward_source_id] == 0) { - queue.push(n.backward_source_id); + if (--out_degree[n.source_id] == 0) { + queue.push(n.source_id); } } } @@ -79,7 +79,7 @@ bool StaticGraph::InferNodeShapes(const std::vector &topo_order, } } else if (nodes[nid].is_backward()) { // simply use shapes from forward pass to assign backward shape - const Node& forward = nodes[node.backward_source_id]; + const Node& forward = nodes[node.source_id]; CHECK(forward.is_forward()); std::vector& in_grad_shapes = (*node_out_shapes)[nid]; CHECK(in_grad_shapes.size() == forward.inputs.size()); @@ -99,7 +99,7 @@ bool StaticGraph::InferNodeShapes(const std::vector &topo_order, } } // consistent check for input shapes - auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id]; + auto& out_data_shapes = (*node_out_shapes)[node.source_id]; // use BackwardInputs to select entries corresponding to node.inputs auto in_shape = forward.op->BackwardInputs( out_data_shapes, in_grad_shapes, out_data_shapes); @@ -130,7 +130,7 @@ bool StaticGraph::InferShape(std::vector *in_shape, if (nodes[i].is_forward()) { nout = nodes[i].op->NumReturns(); } else if (nodes[i].is_backward()) { - nout = static_cast(nodes[nodes[i].backward_source_id].inputs.size()); + nout = static_cast(nodes[nodes[i].source_id].inputs.size()); } node_out_shapes[i].resize(nout); } @@ -198,7 +198,6 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, uint32_t nid = *it; // skip variables if (nodes[nid].is_variable()) continue; - CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward"; // get out_grad and out_data entry std::vector out_grad, out_data; // nvisible is out_grad.size() @@ -229,7 +228,13 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, // Create a gradient backward node 
Node grad_node; // Point to the corresponding source - grad_node.backward_source_id = nid; + grad_node.source_id = nid; + // reverse the source node + grad_node.backward = !(nodes[grad_node.source_id].backward); + // if grad node is a forward node, needs to have its own OpProperty + if (!grad_node.backward) { + grad_node.op.reset(nodes[nodes[nid].source_id].op->Copy()); + } // select out the dependent inputs grad_node.inputs = nodes[nid].op->BackwardInputs( out_grad, nodes[nid].inputs, out_data); From ec1a11f40be78e3d9137f75a1fbf04d12252ab5e Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 21 Aug 2015 17:44:46 +0800 Subject: [PATCH 3/3] format annotation --- src/symbol/symbol.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 56ad0a869540..aecac3dda487 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -1,7 +1,7 @@ /*! - * Copyright (c) 2015 by Contributors - *\file symbol.cc - *\brief symbol of mxnet + * Copyright (c) 2015 by Contributors + * \file symbol.cc + * \brief symbol of mxnet */ #include #include @@ -12,13 +12,13 @@ namespace mxnet { /*! - *\brief Node is represents node of an operator in the symbolic graph. + * \brief Node is represents node of an operator in the symbolic graph. * - *It stores connection to the inputs to function represented by OperatorProperty - *NOTE on data structure: there are three types of node: - *- Normal node: contains all the necessary elements of a graph. - *- OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied. - *- Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. + * It stores connection to the inputs to function represented by OperatorProperty + * NOTE on data structure: there are three types of node: + * - Normal node: contains all the necessary elements of a graph. 
+ * - OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied. + * - Variable: the sym_ is nullptr, represents a named Variable of tensors that can be composed. */ struct Symbol::Node { /*! \brief Operator of this node */