Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #23 from mavenlin/master
Browse files Browse the repository at this point in the history
support backward of backward
  • Loading branch information
mavenlin committed Aug 21, 2015
2 parents 0635103 + ec1a11f commit 860d2ee
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 41 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ class OperatorProperty {
for (size_t i = 0; i < ret_index.size(); ++i) {
ret[i] = all_data[ret_index[i]];
}
return std::move(ret);
return ret;
}
/*!
* \brief create OperatorProperty
Expand Down
29 changes: 19 additions & 10 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,27 +85,36 @@ class StaticGraph {
/*! \brief inputs (node_id, index) for of the nodes*/
std::vector<DataEntry> inputs;
/*!
* \brief If this field is nonnegative, this indicates this
* Node is corresponds to a Backward Operation of Operator.
* backward_source_id will points to the corresponding Forward Node.
* \brief source node id; if this field is negative, it means this
* Node is a forward node. If this field is nonnegative, it
* means this Node is the gradient of the source node.
*/
int32_t source_id;
/*!
* \brief backward; if this field is true, it means this node
* represents the backward function of the op. Otherwise, it
* represents the forward function. When it represents the
* backward function, the node has no op of its own but shares
* the op of the source node, because the backward function
* shares state with the forward function and the two must use the same op.
*
* For normal node, this field is -1.
* When the node is a Backward node, the op field will be nullptr
* Since we support gradient of gradient, a forward node can also
* be the gradient of another node. See source id.
*/
int32_t backward_source_id;
bool backward;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}
Node() : source_id(-1), backward(false) {}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
return !backward && !is_variable();
}
/*! \return whether the node is backward op node */
inline bool is_backward() const {
return backward_source_id != -1;
return backward;
}
/*! \return whether the node is variable node */
inline bool is_variable() const {
return op == nullptr && !is_backward();
return op == nullptr && source_id == -1;
}
};
/*! \brief all nodes in the graph */
Expand Down
16 changes: 8 additions & 8 deletions src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ GraphExecutor::GetResource(uint32_t node_id) const {
return node.op->ForwardResource();
} else {
CHECK(node.is_backward());
return graph_.nodes[node.backward_source_id].op->BackwardResource();
return graph_.nodes[node.source_id].op->BackwardResource();
}
}

Expand All @@ -90,7 +90,7 @@ inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const {
return node.op->NumReturns();
} else if (node.is_backward()) {
return static_cast<int>(
graph_.nodes[node.backward_source_id].op->ListArguments().size());
graph_.nodes[node.source_id].op->ListArguments().size());
} else {
CHECK(node.is_variable());
return 1;
Expand Down Expand Up @@ -121,11 +121,11 @@ inline std::vector<std::pair<T, T> > GraphExecutor::GetInplaceOption(
remap[i].first = in_data[rmap_index[i].first];
remap[i].second = *static_cast<const T*>(rmap_index[i].second);
}
return std::move(remap);
return remap;
} else {
CHECK(node.is_backward());
// forward property
const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get();
const OperatorProperty *fwd = graph_.nodes[node.source_id].op.get();

std::vector<int> out_grad_index(fwd->NumVisibleReturns());
std::vector<int> in_data_index(fwd->ListArguments().size());
Expand Down Expand Up @@ -161,7 +161,7 @@ inline std::vector<std::pair<T, T> > GraphExecutor::GetInplaceOption(
remap[i].first = *args_array[remap_index[i].first];
remap[i].second = *static_cast<T*>(remap_index[i].second);
}
return std::move(remap);
return remap;
}
}

Expand Down Expand Up @@ -196,7 +196,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
op_ctx_ptr->run_ctx = ctx;
op->Forward(*op_ctx_ptr, in_data, req, out_data);
};
return std::move(exec);
return exec;
}

void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) {
Expand Down Expand Up @@ -406,8 +406,8 @@ void GraphExecutor::InitOpNodes() {
} else {
CHECK(graph_.nodes[nid].is_backward());
op_node.op.reset(new BackwardOpWrapper(
graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(),
op_nodes_[graph_.nodes[nid].backward_source_id].op));
graph_.nodes[graph_.nodes[nid].source_id].op.get(),
op_nodes_[graph_.nodes[nid].source_id].op));
}
bool allow_cache = true;
for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) {
Expand Down
25 changes: 15 additions & 10 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
++out_degree[e.source_id];
}
if (n.is_backward()) {
++out_degree[n.backward_source_id];
++out_degree[n.source_id];
}
}
std::vector<uint32_t> ret(nodes.size());
Expand All @@ -41,12 +41,12 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
}
}
if (n.is_backward()) {
if (--out_degree[n.backward_source_id] == 0) {
queue.push(n.backward_source_id);
if (--out_degree[n.source_id] == 0) {
queue.push(n.source_id);
}
}
}
return std::move(ret);
return ret;
}

bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
Expand Down Expand Up @@ -79,7 +79,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
}
} else if (nodes[nid].is_backward()) {
// simply use shapes from forward pass to assign backward shape
const Node& forward = nodes[node.backward_source_id];
const Node& forward = nodes[node.source_id];
CHECK(forward.is_forward());
std::vector<TShape>& in_grad_shapes = (*node_out_shapes)[nid];
CHECK(in_grad_shapes.size() == forward.inputs.size());
Expand All @@ -99,7 +99,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
}
}
// consistent check for input shapes
auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id];
auto& out_data_shapes = (*node_out_shapes)[node.source_id];
// use BackwardInputs to select entries corresponding to node.inputs
auto in_shape = forward.op->BackwardInputs(
out_data_shapes, in_grad_shapes, out_data_shapes);
Expand Down Expand Up @@ -130,7 +130,7 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
if (nodes[i].is_forward()) {
nout = nodes[i].op->NumReturns();
} else if (nodes[i].is_backward()) {
nout = static_cast<int>(nodes[nodes[i].backward_source_id].inputs.size());
nout = static_cast<int>(nodes[nodes[i].source_id].inputs.size());
}
node_out_shapes[i].resize(nout);
}
Expand Down Expand Up @@ -161,7 +161,7 @@ StaticGraph::Node StaticGraph::CreateSumNode(
os_size << grad_source.size();
agg_node.op->Init({{"size", os_size.str()}});
agg_node.inputs = grad_source;
return std::move(agg_node);
return agg_node;
}

void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
Expand Down Expand Up @@ -198,7 +198,6 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
uint32_t nid = *it;
// skip variables
if (nodes[nid].is_variable()) continue;
CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward";
// get out_grad and out_data entry
std::vector<DataEntry> out_grad, out_data;
// nvisible is out_grad.size()
Expand Down Expand Up @@ -229,7 +228,13 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
// Create a gradient backward node
Node grad_node;
// Point to the corresponding source
grad_node.backward_source_id = nid;
grad_node.source_id = nid;
// reverse the source node
grad_node.backward = !(nodes[grad_node.source_id].backward);
// if grad node is a forward node, needs to have its own OpProperty
if (!grad_node.backward) {
grad_node.op.reset(nodes[nodes[nid].source_id].op->Copy());
}
// select out the dependent inputs
grad_node.inputs = nodes[nid].op->BackwardInputs(
out_grad, nodes[nid].inputs, out_data);
Expand Down
24 changes: 12 additions & 12 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2015 by Contributors
*\file symbol.cc
*\brief symbol of mxnet
* Copyright (c) 2015 by Contributors
* \file symbol.cc
* \brief symbol of mxnet
*/
#include <dmlc/logging.h>
#include <mxnet/symbolic.h>
Expand All @@ -12,13 +12,13 @@

namespace mxnet {
/*!
*\brief Node is represents node of an operator in the symbolic graph.
* \brief Node represents a node of an operator in the symbolic graph.
*
*It stores connection to the inputs to function represented by OperatorProperty
*NOTE on data structure: there are three types of node:
*- Normal node: contains all the necessary elements of a graph.
*- OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied.
*- Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed.
* It stores connections from the inputs to the function represented by OperatorProperty.
* NOTE on data structure: there are three types of node:
* - Normal node: contains all the necessary elements of a graph.
* - OperatorProperty: the inputs_ is empty; represents an OperatorProperty that has not been applied.
* - Variable: the sym_ is nullptr; represents a named Variable of tensors that can be composed.
*/
struct Symbol::Node {
/*! \brief Operator of this node */
Expand Down Expand Up @@ -201,7 +201,7 @@ std::vector<std::string> Symbol::ListReturns() const {
}
}
}
return std::move(ret);
return ret;
}

Symbol Symbol::operator[] (size_t index) const {
Expand Down Expand Up @@ -415,13 +415,13 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
for (const auto &s : symbols) {
ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end());
}
return std::move(ret);
return ret;
}

Symbol Symbol::CreateVariable(const std::string &name) {
Symbol s;
s.heads_.push_back(DataEntry(std::make_shared<Node>(nullptr, name), 0));
return std::move(s);
return s;
}

void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
Expand Down

0 comments on commit 860d2ee

Please sign in to comment.