From c6385acd4cf32e3fbbd710b53055f63e406d6c41 Mon Sep 17 00:00:00 2001 From: Lin Min Date: Fri, 21 Aug 2015 23:03:34 +0800 Subject: [PATCH] Revert "support backward of backward" --- include/mxnet/operator.h | 2 +- include/mxnet/symbolic.h | 29 ++++++++++------------------- src/symbol/graph_executor.cc | 16 ++++++++-------- src/symbol/static_graph.cc | 25 ++++++++++--------------- src/symbol/symbol.cc | 24 ++++++++++++------------ 5 files changed, 41 insertions(+), 55 deletions(-) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index b44a50d2db7d..e60afe6948a7 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -353,7 +353,7 @@ class OperatorProperty { for (size_t i = 0; i < ret_index.size(); ++i) { ret[i] = all_data[ret_index[i]]; } - return ret; + return std::move(ret); } /*! * \brief create OperatorProperty diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index f97a695b7f19..df06c4913de8 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -85,36 +85,27 @@ class StaticGraph { /*! \brief inputs (node_id, index) for of the nodes*/ std::vector inputs; /*! - * \brief source node id; if this field is negative, it means this - * Node is a forward node. If this field is nonnegative, it - * means this Node is the gradient of the source node. - */ - int32_t source_id; - /*! - * \brief backward; if this field is true, that means this node - * represents the backward function of the op. Else, it - * represents the forward function. When it represents the - * backward function, itself has not op but shares from the - * source node. It is because the backward function shares the - * states from the forward, and they need to share op. + * \brief If this field is nonnegative, this indicates this + * Node is corresponds to a Backward Operation of Operator. + * backward_source_id will points to the corresponding Forward Node. * - * Since we support gradient of gradient, a forward node can also - * be the gradient of another node. See source id. + * For normal node, this field is -1. + * When the node is a Backward node, the op field will be nullptr */ - bool backward; + int32_t backward_source_id; /*! \brief default constructor */ - Node() : source_id(-1), backward(false) {} + Node() : backward_source_id(-1) {} /*! \return whether the node is forward op node */ inline bool is_forward() const { - return !backward && !is_variable(); + return op != nullptr; } /*! \return whether the node is backward op node */ inline bool is_backward() const { - return backward; + return backward_source_id != -1; } /*! \return whether the node is variable node */ inline bool is_variable() const { - return op == nullptr && source_id == -1; + return op == nullptr && !is_backward(); } }; /*! \brief all nodes in the graph */ diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index d058242ef11e..a434f22a2fc6 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -80,7 +80,7 @@ GraphExecutor::GetResource(uint32_t node_id) const { return node.op->ForwardResource(); } else { CHECK(node.is_backward()); - return graph_.nodes[node.source_id].op->BackwardResource(); + return graph_.nodes[node.backward_source_id].op->BackwardResource(); } } @@ -90,7 +90,7 @@ inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const { return node.op->NumReturns(); } else if (node.is_backward()) { return static_cast( - graph_.nodes[node.source_id].op->ListArguments().size()); + graph_.nodes[node.backward_source_id].op->ListArguments().size()); } else { CHECK(node.is_variable()); return 1; @@ -121,11 +121,11 @@ inline std::vector > GraphExecutor::GetInplaceOption( remap[i].first = in_data[rmap_index[i].first]; remap[i].second = *static_cast(rmap_index[i].second); } - return remap; + return std::move(remap); } else { CHECK(node.is_backward()); // forward property - const OperatorProperty *fwd = graph_.nodes[node.source_id].op.get(); + const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get(); std::vector out_grad_index(fwd->NumVisibleReturns()); std::vector in_data_index(fwd->ListArguments().size()); @@ -161,7 +161,7 @@ inline std::vector > GraphExecutor::GetInplaceOption( remap[i].first = *args_array[remap_index[i].first]; remap[i].second = *static_cast(remap_index[i].second); } - return remap; + return std::move(remap); } } @@ -196,7 +196,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { op_ctx_ptr->run_ctx = ctx; op->Forward(*op_ctx_ptr, in_data, req, out_data); }; - return exec; + return std::move(exec); } void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { @@ -406,8 +406,8 @@ void GraphExecutor::InitOpNodes() { } else { CHECK(graph_.nodes[nid].is_backward()); op_node.op.reset(new BackwardOpWrapper( - graph_.nodes[graph_.nodes[nid].source_id].op.get(), - op_nodes_[graph_.nodes[nid].source_id].op)); + graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(), + op_nodes_[graph_.nodes[nid].backward_source_id].op)); } bool allow_cache = true; for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 82dc61193566..5eb0ad14a282 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -18,7 +18,7 @@ std::vector StaticGraph::TopoSort() const { ++out_degree[e.source_id]; } if (n.is_backward()) { - ++out_degree[n.source_id]; + ++out_degree[n.backward_source_id]; } } std::vector ret(nodes.size()); @@ -41,12 +41,12 @@ std::vector StaticGraph::TopoSort() const { } } if (n.is_backward()) { - if (--out_degree[n.source_id] == 0) { - queue.push(n.source_id); + if (--out_degree[n.backward_source_id] == 0) { + queue.push(n.backward_source_id); } } } - return ret; + return std::move(ret); } bool StaticGraph::InferNodeShapes(const std::vector &topo_order, @@ -79,7 +79,7 @@ bool StaticGraph::InferNodeShapes(const std::vector &topo_order, } } else if (nodes[nid].is_backward()) { // simply use shapes from forward pass to assign backward shape - const Node& forward = nodes[node.source_id]; + const Node& forward = nodes[node.backward_source_id]; CHECK(forward.is_forward()); std::vector& in_grad_shapes = (*node_out_shapes)[nid]; CHECK(in_grad_shapes.size() == forward.inputs.size()); @@ -99,7 +99,7 @@ bool StaticGraph::InferNodeShapes(const std::vector &topo_order, } } // consistent check for input shapes - auto& out_data_shapes = (*node_out_shapes)[node.source_id]; + auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id]; // use BackwardInputs to select entries corresponding to node.inputs auto in_shape = forward.op->BackwardInputs( out_data_shapes, in_grad_shapes, out_data_shapes); @@ -130,7 +130,7 @@ bool StaticGraph::InferShape(std::vector *in_shape, if (nodes[i].is_forward()) { nout = nodes[i].op->NumReturns(); } else if (nodes[i].is_backward()) { - nout = static_cast(nodes[nodes[i].source_id].inputs.size()); + nout = static_cast(nodes[nodes[i].backward_source_id].inputs.size()); } node_out_shapes[i].resize(nout); } @@ -161,7 +161,7 @@ StaticGraph::Node StaticGraph::CreateSumNode( os_size << grad_source.size(); agg_node.op->Init({{"size", os_size.str()}}); agg_node.inputs = grad_source; - return agg_node; + return std::move(agg_node); } void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, @@ -198,6 +198,7 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, uint32_t nid = *it; // skip variables if (nodes[nid].is_variable()) continue; + CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward"; // get out_grad and out_data entry std::vector out_grad, out_data; // nvisible is out_grad.size() @@ -228,13 +229,7 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, // Create a gradient backward node Node grad_node; // Point to the corresponding source - grad_node.source_id = nid; - // reverse the source node - grad_node.backward = !(nodes[grad_node.source_id].backward); - // if grad node is a forward node, needs to have its own OpProperty - if (!grad_node.backward) { - grad_node.op.reset(nodes[nodes[nid].source_id].op->Copy()); - } + grad_node.backward_source_id = nid; // select out the dependent inputs grad_node.inputs = nodes[nid].op->BackwardInputs( out_grad, nodes[nid].inputs, out_data); diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index aecac3dda487..54a5fe9422b2 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -1,7 +1,7 @@ /*! - * Copyright (c) 2015 by Contributors - * \file symbol.cc - * \brief symbol of mxnet + * Copyright (c) 2015 by Contributors + *\file symbol.cc + *\brief symbol of mxnet */ #include #include @@ -12,13 +12,13 @@ namespace mxnet { /*! - * \brief Node is represents node of an operator in the symbolic graph. + *\brief Node is represents node of an operator in the symbolic graph. * - * It stores connection to the inputs to function represented by OperatorProperty - * NOTE on data structure: there are three types of node: - * - Normal node: contains all the necessary elements of a graph. - * - OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied. - * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. + *It stores connection to the inputs to function represented by OperatorProperty + *NOTE on data structure: there are three types of node: + *- Normal node: contains all the necessary elements of a graph. + *- OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied. + *- Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. */ struct Symbol::Node { /*! \brief Operator of this node */ @@ -201,7 +201,7 @@ std::vector Symbol::ListReturns() const { } } } - return ret; + return std::move(ret); } Symbol Symbol::operator[] (size_t index) const { @@ -415,13 +415,13 @@ Symbol Symbol::CreateGroup(const std::vector &symbols) { for (const auto &s : symbols) { ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end()); } - return ret; + return std::move(ret); } Symbol Symbol::CreateVariable(const std::string &name) { Symbol s; s.heads_.push_back(DataEntry(std::make_shared(nullptr, name), 0)); - return s; + return std::move(s); } void Symbol::ToStaticGraph(StaticGraph *out_graph) const {