diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h
index df06c4913de8..f97a695b7f19 100644
--- a/include/mxnet/symbolic.h
+++ b/include/mxnet/symbolic.h
@@ -85,27 +85,36 @@ class StaticGraph {
     /*! \brief inputs (node_id, index) for of the nodes*/
     std::vector<DataEntry> inputs;
     /*!
-     * \brief If this field is nonnegative, this indicates this
-     *  Node is corresponds to a Backward Operation of Operator.
-     *  backward_source_id will points to the corresponding Forward Node.
+     * \brief source node id; if this field is negative, this Node is
+     *  an ordinary (non-gradient) node. If it is nonnegative, this
+     *  Node is the gradient of the source node.
+     */
+    int32_t source_id;
+    /*!
+     * \brief backward flag; if this field is true, this node
+     *  represents the backward function of the op; otherwise it
+     *  represents the forward function. A backward node has no
+     *  op of its own but shares the op of its source node,
+     *  because the backward function shares state with the
+     *  forward function, so the two must share the same op.
      *
-     *  For normal node, this field is -1.
-     *  When the node is a Backward node, the op field will be nullptr
+     *  Since we support gradient of gradient, a forward node can also
+     *  be the gradient of another node. See source_id.
      */
-    int32_t backward_source_id;
+    bool backward;
     /*! \brief default constructor */
-    Node() : backward_source_id(-1) {}
+    Node() : source_id(-1), backward(false) {}
     /*! \return whether the node is forward op node */
     inline bool is_forward() const {
-      return op != nullptr;
+      return !backward && !is_variable();
     }
     /*! \return whether the node is backward op node */
     inline bool is_backward() const {
-      return backward_source_id != -1;
+      return backward;
     }
     /*! \return whether the node is variable node */
     inline bool is_variable() const {
-      return op == nullptr && !is_backward();
+      return op == nullptr && source_id == -1;
     }
   };
   /*! \brief all nodes in the graph */
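Note: the interplay of the new `source_id` and `backward` fields is easiest to see in isolation. The sketch below is not part of the patch; `Op` is a stand-in for `OperatorProperty`, and the struct is trimmed to the fields involved in the classification. It shows how the three predicates distinguish a variable, a forward node, a gradient node, and a gradient-of-gradient node.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>

// Op is a placeholder for OperatorProperty; only the fields relevant to
// the forward/backward/variable distinction are kept.
struct Op {};

struct Node {
  std::unique_ptr<Op> op;   // owned only by nodes acting as forward ops
  int32_t source_id = -1;   // >= 0: this node is the gradient of nodes[source_id]
  bool backward = false;    // true: shares the op of its source node

  bool is_variable() const { return op == nullptr && source_id == -1; }
  bool is_backward() const { return backward; }
  bool is_forward() const { return !backward && !is_variable(); }
};

int main() {
  Node var;                        // plain variable: no op, no source
  assert(var.is_variable());

  Node fwd;                        // ordinary forward op node
  fwd.op.reset(new Op());
  assert(fwd.is_forward());

  Node grad;                       // gradient of a forward node: backward, no own op
  grad.source_id = 1;
  grad.backward = true;
  assert(grad.is_backward() && !grad.is_forward());

  Node grad_grad;                  // gradient of a backward node: forward again,
  grad_grad.source_id = 2;         // so it owns (a copy of) the op
  grad_grad.op.reset(new Op());
  assert(grad_grad.is_forward());
  return 0;
}
```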
diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc
index 8dbadb34e24e..d058242ef11e 100644
--- a/src/symbol/graph_executor.cc
+++ b/src/symbol/graph_executor.cc
@@ -80,7 +80,7 @@ GraphExecutor::GetResource(uint32_t node_id) const {
     return node.op->ForwardResource();
   } else {
     CHECK(node.is_backward());
-    return graph_.nodes[node.backward_source_id].op->BackwardResource();
+    return graph_.nodes[node.source_id].op->BackwardResource();
   }
 }
@@ -90,7 +90,7 @@ inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const {
     return node.op->NumReturns();
   } else if (node.is_backward()) {
     return static_cast<int>(
-        graph_.nodes[node.backward_source_id].op->ListArguments().size());
+        graph_.nodes[node.source_id].op->ListArguments().size());
   } else {
     CHECK(node.is_variable());
     return 1;
@@ -125,7 +125,7 @@ inline std::vector<std::pair<T, T> > GraphExecutor::GetInplaceOption(
   } else {
     CHECK(node.is_backward());
     // forward property
-    const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get();
+    const OperatorProperty *fwd = graph_.nodes[node.source_id].op.get();
     std::vector<int> out_grad_index(fwd->NumVisibleReturns());
     std::vector<int> in_data_index(fwd->ListArguments().size());
@@ -406,8 +406,8 @@ void GraphExecutor::InitOpNodes() {
     } else {
       CHECK(graph_.nodes[nid].is_backward());
       op_node.op.reset(new BackwardOpWrapper(
-          graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(),
-          op_nodes_[graph_.nodes[nid].backward_source_id].op));
+          graph_.nodes[graph_.nodes[nid].source_id].op.get(),
+          op_nodes_[graph_.nodes[nid].source_id].op));
     }
     bool allow_cache = true;
     for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) {
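Note: every executor-side change above is the same one-step indirection: a backward node owns no op, so each query follows `source_id` to the node that does. A minimal standalone sketch of that lookup under the new fields (`ResolveOp` is an illustrative helper, not an MXNet function):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct Op {};                      // stand-in for OperatorProperty

struct Node {
  std::unique_ptr<Op> op;          // owned only by forward nodes
  int32_t source_id = -1;
  bool backward = false;
  bool is_backward() const { return backward; }
};

// Return the op a node should use: its own for forward nodes, the source
// node's for backward nodes (the pattern behind GetResource, GetNumOutputs
// and InitOpNodes above).
const Op* ResolveOp(const std::vector<Node>& nodes, uint32_t nid) {
  const Node& n = nodes[nid];
  if (!n.is_backward()) return n.op.get();
  return nodes[n.source_id].op.get();   // backward node: delegate to the source
}

int main() {
  std::vector<Node> nodes(2);
  nodes[0].op.reset(new Op());     // forward node that owns the op
  nodes[1].source_id = 0;          // its backward node
  nodes[1].backward = true;
  assert(ResolveOp(nodes, 1) == nodes[0].op.get());
  return 0;
}
```

Sharing one op this way keeps the forward and backward functions operating on the same internal state, which is the reason the backward node deliberately never allocates its own OperatorProperty.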
diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc
index c9ed278b8f7e..82dc61193566 100644
--- a/src/symbol/static_graph.cc
+++ b/src/symbol/static_graph.cc
@@ -18,7 +18,7 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
       ++out_degree[e.source_id];
     }
     if (n.is_backward()) {
-      ++out_degree[n.backward_source_id];
+      ++out_degree[n.source_id];
     }
   }
   std::vector<uint32_t> ret(nodes.size());
@@ -41,8 +41,8 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
       }
     }
     if (n.is_backward()) {
-      if (--out_degree[n.backward_source_id] == 0) {
-        queue.push(n.backward_source_id);
+      if (--out_degree[n.source_id] == 0) {
+        queue.push(n.source_id);
       }
     }
   }
@@ -79,7 +79,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
       }
     } else if (nodes[nid].is_backward()) {
       // simply use shapes from forward pass to assign backward shape
-      const Node& forward = nodes[node.backward_source_id];
+      const Node& forward = nodes[node.source_id];
       CHECK(forward.is_forward());
       std::vector<TShape>& in_grad_shapes = (*node_out_shapes)[nid];
       CHECK(in_grad_shapes.size() == forward.inputs.size());
@@ -99,7 +99,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
         }
       }
       // consistent check for input shapes
-      auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id];
+      auto& out_data_shapes = (*node_out_shapes)[node.source_id];
       // use BackwardInputs to select entries corresponding to node.inputs
       auto in_shape = forward.op->BackwardInputs(
           out_data_shapes, in_grad_shapes, out_data_shapes);
@@ -130,7 +130,7 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
     if (nodes[i].is_forward()) {
       nout = nodes[i].op->NumReturns();
     } else if (nodes[i].is_backward()) {
-      nout = static_cast<int>(nodes[nodes[i].backward_source_id].inputs.size());
+      nout = static_cast<int>(nodes[nodes[i].source_id].inputs.size());
     }
     node_out_shapes[i].resize(nout);
   }
@@ -198,7 +198,6 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
     uint32_t nid = *it;
     // skip variables
    if (nodes[nid].is_variable()) continue;
-    CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward";
     // get out_grad and out_data entry
     std::vector<DataEntry> out_grad, out_data;
     // nvisible is out_grad.size()
@@ -229,7 +228,13 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
     // Create a gradient backward node
     Node grad_node;
     // Point to the corresponding source
-    grad_node.backward_source_id = nid;
+    grad_node.source_id = nid;
+    // flip the backward flag relative to the source node
+    grad_node.backward = !(nodes[grad_node.source_id].backward);
+    // if the grad node acts as a forward node, it needs its own OperatorProperty
+    if (!grad_node.backward) {
+      grad_node.op.reset(nodes[nodes[nid].source_id].op->Copy());
+    }
     // select out the dependent inputs
     grad_node.inputs = nodes[nid].op->BackwardInputs(
         out_grad, nodes[nid].inputs, out_data);
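Note: with the `CHECK` on forward nodes removed, taking a gradient simply flips the `backward` flag of the source node, so grad(forward) is a backward node that shares its source's op, while grad(backward) becomes a forward node again and receives its own copy of the op. A toy standalone sketch of that rule (`MakeGradNode` and the simplified `Op::Copy` are illustrative, not the MXNet API):

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

struct Op {
  Op* Copy() const { return new Op(*this); }   // stand-in for OperatorProperty::Copy()
};

struct Node {
  std::unique_ptr<Op> op;
  int32_t source_id = -1;
  bool backward = false;
};

// Gradient node of nodes[nid], following the rule in MakeBackwardPass: flip
// the backward flag of the source; if the result acts as a forward node,
// give it its own copy of the underlying op.
Node MakeGradNode(const std::vector<Node>& nodes, uint32_t nid) {
  Node grad;
  grad.source_id = static_cast<int32_t>(nid);
  grad.backward = !nodes[nid].backward;
  if (!grad.backward) {
    // nid is itself a backward node, so the op it uses lives at its source.
    grad.op.reset(nodes[nodes[nid].source_id].op->Copy());
  }
  return grad;
}

int main() {
  std::vector<Node> nodes(1);
  nodes[0].op.reset(new Op());                 // forward node

  Node g1 = MakeGradNode(nodes, 0);            // first gradient: backward, shares op
  assert(g1.backward && g1.op == nullptr);

  nodes.push_back(std::move(g1));              // becomes nodes[1]
  Node g2 = MakeGradNode(nodes, 1);            // gradient of gradient: forward again
  assert(!g2.backward && g2.op != nullptr);    // owns its own copied op
  return 0;
}
```

This flip is what makes second-order gradients possible here: the second application lands back on a node with its own OperatorProperty, so the existing forward machinery can execute it unchanged.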