Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support backward of backward
Browse files Browse the repository at this point in the history
  • Loading branch information
mavenlin committed Aug 21, 2015
1 parent 0f77b3e commit 32a04c3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
29 changes: 19 additions & 10 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,27 +85,36 @@ class StaticGraph {
/*! \brief inputs (node_id, index) for of the nodes*/
std::vector<DataEntry> inputs;
/*!
 * \brief If this field is nonnegative, it indicates that this
 * Node corresponds to a Backward Operation of an Operator.
 * backward_source_id will point to the corresponding Forward Node.
* \brief source node id; if this field is negative, it means this
* Node is a forward node. If this field is nonnegative, it
* means this Node is the gradient of the source node.
*/
int32_t source_id;
/*!
 * \brief backward; if this field is true, this node
 * represents the backward function of the op; otherwise it
 * represents the forward function. When it represents the
 * backward function, it has no op of its own but shares the
 * op of the source node, because the backward function shares
 * state with the forward pass, so the two must share one op.
*
* For normal node, this field is -1.
* When the node is a Backward node, the op field will be nullptr
* Since we support gradient of gradient, a forward node can also
 * be the gradient of another node. See source_id.
*/
int32_t backward_source_id;
bool backward;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}
Node() : source_id(-1), backward(false) {}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
return !backward && !is_variable();
}
/*! \return whether the node is backward op node */
inline bool is_backward() const {
return backward_source_id != -1;
return backward;
}
/*! \return whether the node is variable node */
inline bool is_variable() const {
return op == nullptr && !is_backward();
return op == nullptr && source_id == -1;
}
};
/*! \brief all nodes in the graph */
Expand Down
10 changes: 5 additions & 5 deletions src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ GraphExecutor::GetResource(uint32_t node_id) const {
return node.op->ForwardResource();
} else {
CHECK(node.is_backward());
return graph_.nodes[node.backward_source_id].op->BackwardResource();
return graph_.nodes[node.source_id].op->BackwardResource();
}
}

Expand All @@ -90,7 +90,7 @@ inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const {
return node.op->NumReturns();
} else if (node.is_backward()) {
return static_cast<int>(
graph_.nodes[node.backward_source_id].op->ListArguments().size());
graph_.nodes[node.source_id].op->ListArguments().size());
} else {
CHECK(node.is_variable());
return 1;
Expand Down Expand Up @@ -125,7 +125,7 @@ inline std::vector<std::pair<T, T> > GraphExecutor::GetInplaceOption(
} else {
CHECK(node.is_backward());
// forward property
const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get();
const OperatorProperty *fwd = graph_.nodes[node.source_id].op.get();

std::vector<int> out_grad_index(fwd->NumVisibleReturns());
std::vector<int> in_data_index(fwd->ListArguments().size());
Expand Down Expand Up @@ -406,8 +406,8 @@ void GraphExecutor::InitOpNodes() {
} else {
CHECK(graph_.nodes[nid].is_backward());
op_node.op.reset(new BackwardOpWrapper(
graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(),
op_nodes_[graph_.nodes[nid].backward_source_id].op));
graph_.nodes[graph_.nodes[nid].source_id].op.get(),
op_nodes_[graph_.nodes[nid].source_id].op));
}
bool allow_cache = true;
for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) {
Expand Down
21 changes: 13 additions & 8 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
++out_degree[e.source_id];
}
if (n.is_backward()) {
++out_degree[n.backward_source_id];
++out_degree[n.source_id];
}
}
std::vector<uint32_t> ret(nodes.size());
Expand All @@ -41,8 +41,8 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
}
}
if (n.is_backward()) {
if (--out_degree[n.backward_source_id] == 0) {
queue.push(n.backward_source_id);
if (--out_degree[n.source_id] == 0) {
queue.push(n.source_id);
}
}
}
Expand Down Expand Up @@ -79,7 +79,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
}
} else if (nodes[nid].is_backward()) {
// simply use shapes from forward pass to assign backward shape
const Node& forward = nodes[node.backward_source_id];
const Node& forward = nodes[node.source_id];
CHECK(forward.is_forward());
std::vector<TShape>& in_grad_shapes = (*node_out_shapes)[nid];
CHECK(in_grad_shapes.size() == forward.inputs.size());
Expand All @@ -99,7 +99,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
}
}
// consistent check for input shapes
auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id];
auto& out_data_shapes = (*node_out_shapes)[node.source_id];
// use BackwardInputs to select entries corresponding to node.inputs
auto in_shape = forward.op->BackwardInputs(
out_data_shapes, in_grad_shapes, out_data_shapes);
Expand Down Expand Up @@ -130,7 +130,7 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
if (nodes[i].is_forward()) {
nout = nodes[i].op->NumReturns();
} else if (nodes[i].is_backward()) {
nout = static_cast<int>(nodes[nodes[i].backward_source_id].inputs.size());
nout = static_cast<int>(nodes[nodes[i].source_id].inputs.size());
}
node_out_shapes[i].resize(nout);
}
Expand Down Expand Up @@ -198,7 +198,6 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
uint32_t nid = *it;
// skip variables
if (nodes[nid].is_variable()) continue;
CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward";
// get out_grad and out_data entry
std::vector<DataEntry> out_grad, out_data;
// nvisible is out_grad.size()
Expand Down Expand Up @@ -229,7 +228,13 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
// Create a gradient backward node
Node grad_node;
// Point to the corresponding source
grad_node.backward_source_id = nid;
grad_node.source_id = nid;
// reverse the source node
grad_node.backward = !(nodes[grad_node.source_id].backward);
// if grad node is a forward node, needs to have its own OpProperty
if (!grad_node.backward) {
grad_node.op.reset(nodes[nodes[nid].source_id].op->Copy());
}
// select out the dependent inputs
grad_node.inputs = nodes[nid].op->BackwardInputs(
out_grad, nodes[nid].inputs, out_data);
Expand Down

0 comments on commit 32a04c3

Please sign in to comment.