change to match by both op and name (#109)
* change to match by both op and name

* add zero prop for non-differentiable ops

* Update infer_shape_type.cc
piiswrong authored Mar 31, 2017
1 parent 0d64855 commit ddf3c17
Showing 3 changed files with 59 additions and 7 deletions.
9 changes: 8 additions & 1 deletion include/nnvm/pass_functions.h
@@ -130,6 +130,8 @@ inline Graph PlaceDevice(Graph graph,
  * \param aggregate_fun Aggregation function applied to aggregate the inputs.
  * \param mirror_fun Optional mirror function to do mirror optimization and save memory.
  * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like.
+ * \param zero_ops Optional, list of operators that outputs a single zero array. The first one
+ * must be zeros_like.
  * \return A new graph, whose outputs correspond to inputs of xs.
  */
 inline Graph Gradient(
@@ -140,7 +142,8 @@ inline Graph Gradient(
     std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
     std::function<int(const Node& node)> mirror_fun = nullptr,
     std::function<NodeEntry(const NodeEntry& src, const NodeEntry &like)>
-    attr_hint_fun = nullptr) {
+    attr_hint_fun = nullptr,
+    std::vector<const Op*> zero_ops = std::vector<const Op*>()) {
   graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));

   graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
@@ -157,6 +160,10 @@
     graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
   }

+  if (zero_ops.size()) {
+    graph.attrs["zero_ops"] = std::make_shared<any>(std::move(zero_ops));
+  }
+
   return ApplyPass(std::move(graph), "Gradient");
 }

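Usage note (illustrative, not part of the commit): a minimal sketch of how a caller might pass the new zero_ops argument to nnvm::pass::Gradient. The leading parameters shown here (graph, ys, xs, ys_out_grad) follow the existing Gradient signature, which is only partially visible in the hunk above, and the op names "zeros_like" and "_zeros" are assumed to be registered by the caller; per the new documentation, only the first entry has to be zeros_like.

// Sketch only: calling the Gradient pass with the new zero_ops argument.
// Assumes the ops below are registered and that the elided leading parameters
// of Gradient are (graph, ys, xs, ys_out_grad), as in the existing header.
#include <nnvm/pass_functions.h>
#include <nnvm/op.h>
#include <utility>
#include <vector>

nnvm::Graph TakeGradient(nnvm::Graph g,
                         std::vector<nnvm::NodeEntry> ys,
                         std::vector<nnvm::NodeEntry> xs,
                         std::vector<nnvm::NodeEntry> ys_out_grad) {
  std::vector<const nnvm::Op*> zero_ops;
  zero_ops.push_back(nnvm::Op::Get("zeros_like"));  // must be the first entry
  zero_ops.push_back(nnvm::Op::Get("_zeros"));      // assumed extra zero-producing op
  // Keep the default aggregate/mirror/attr-hint callbacks; only zero_ops is set.
  return nnvm::pass::Gradient(std::move(g), std::move(ys), std::move(xs),
                              std::move(ys_out_grad),
                              nullptr, nullptr, nullptr, zero_ops);
}

If zero_ops is left empty, the "zero_ops" attribute is simply not set and the pass behaves as it did before this commit.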
51 changes: 47 additions & 4 deletions src/pass/gradient.cc
@@ -30,6 +30,22 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
   }
 }

+bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
+                      const std::vector<const Op*>& zero_ops) {
+  if (!grads.size() || !zero_ops.size()) return false;
+  for (const auto& g : grads) {
+    bool found = false;
+    for (const auto& op : zero_ops) {
+      if (g.node->op() == op) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) return false;
+  }
+  return true;
+}
+
 // helper entry
 struct GradEntry {
 #ifdef _MSC_VER
@@ -71,6 +87,10 @@ Graph Gradient(Graph src) {
   if (src.attrs.count("attr_hint_fun") != 0) {
     attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
   }
+  std::vector<const Op*> zero_ops;
+  if (src.attrs.count("zero_ops") != 0) {
+    zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
+  }

   // topo sort
   std::vector<NodePtr> topo_order;
@@ -130,10 +150,33 @@
     }
     if ((*rit)->inputs.size() != 0) {
       NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
-      std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
-          fwd_node, out_agg_grads);
-      CHECK_EQ((*rit)->inputs.size(), input_grads.size())
-          << "Gradient function not returning enough gradient";
+      std::vector<NodeEntry> input_grads;
+      if (grad_fun_map.count(ptr->op())) {
+        input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
+        CHECK_EQ((*rit)->inputs.size(), input_grads.size())
+            << "Gradient function not returning enough gradient";
+      } else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
+        for (index_t i = 0; i < fwd_node->num_inputs(); ++i) {
+          std::ostringstream os;
+          if (1 == fwd_node->num_inputs()) {
+            os << fwd_node->attrs.name << "_backward";
+          } else {
+            os << fwd_node->attrs.name << "_in" << i << "_backward";
+          }
+          auto p = Node::Create();
+          p->attrs.op = zero_ops[0];
+          p->attrs.name = os.str();
+          p->inputs.push_back(fwd_node->inputs[i]);
+          p->control_deps.emplace_back(fwd_node);
+          if (p->op()->attr_parser != nullptr) {
+            p->op()->attr_parser(&(p->attrs));
+          }
+          input_grads.emplace_back(nnvm::NodeEntry{p, 0, 0});
+        }
+      } else {
+        LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
+                   << "because it didn't register FGradient attribute.";
+      }
       auto git = input_grads.begin();
       for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
         auto& ge = output_grads[it->node.get()][it->index];
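Note (illustrative, not part of the commit): the new else-if branch above synthesizes one zero gradient per input when an op registers no FGradient but all of its incoming output gradients are already produced by one of the zero_ops. The same per-input construction is restated below as a standalone helper for readability; the free function MakeZeroGradient is hypothetical, and the pass itself builds these nodes inline using zero_ops[0].

// Per-input zero-gradient construction, mirroring the branch above.
// Hypothetical helper for illustration only.
#include <nnvm/node.h>
#include <nnvm/op.h>
#include <cstdint>
#include <sstream>

nnvm::NodeEntry MakeZeroGradient(const nnvm::NodePtr& fwd_node,
                                 const nnvm::Op* zeros_like_op,  // zero_ops[0]
                                 uint32_t i) {
  // Naming convention: "<fwd>_backward" for single-input ops,
  // "<fwd>_in<i>_backward" otherwise.
  std::ostringstream os;
  if (fwd_node->num_inputs() == 1) {
    os << fwd_node->attrs.name << "_backward";
  } else {
    os << fwd_node->attrs.name << "_in" << i << "_backward";
  }
  auto p = nnvm::Node::Create();
  p->attrs.op = zeros_like_op;
  p->attrs.name = os.str();
  p->inputs.push_back(fwd_node->inputs[i]);  // zeros take their shape from this input
  p->control_deps.emplace_back(fwd_node);    // keep the dependency on the forward node
  if (p->op()->attr_parser != nullptr) {
    p->op()->attr_parser(&(p->attrs));       // let the op parse its (empty) attributes
  }
  return nnvm::NodeEntry{p, 0, 0};
}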
6 changes: 4 additions & 2 deletions src/pass/infer_shape_type.cc
@@ -93,18 +93,18 @@ Graph InferAttr(Graph &&ret,
           << "BackwardOp need to have control_deps to its forward op";
       const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
       NodePtr fwd_ptr = inode.source->control_deps[0];
+      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
       // use gradient function to find out the correspondence.
       std::vector<NodeEntry> ograd(fwd_ptr->num_outputs());
       for (size_t i = 0; i < ograd.size(); ++i) {
         ograd[i].index = static_cast<uint32_t>(i);
       }
       // input gradient list
       auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
-      const Op* backward_op = inode.source->op();
       const Node* igrad_node = nullptr;
       // Input gradient assignement
       for (size_t i = 0; i < igrad.size(); ++i) {
-        if (igrad[i].node->op() == backward_op) {
+        if (igrad[i].node->op() == inode.source->op()) {
           uint32_t eid = idx.entry_id(nid, igrad[i].index);
           if (fis_none(rshape[eid])) {
             rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
@@ -120,6 +120,8 @@
         }
       }
       // out grad entries
+      CHECK(igrad_node != nullptr)
+          << "Cannot find matching backward op for " << inode.source->attrs.name;
       for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
         const NodeEntry& e = igrad_node->inputs[i];
         if (e.node == nullptr) {
