fix blockgrad (#5678)
add zero prop

add test

fix

fix

fix

fix
piiswrong authored Apr 5, 2017
1 parent 5c69214 commit 9a65a2e
Showing 15 changed files with 234 additions and 87 deletions.
2 changes: 1 addition & 1 deletion nnvm
29 changes: 20 additions & 9 deletions src/executor/graph_executor.cc
@@ -91,26 +91,31 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
static const Op* ewise_sum_op = Op::Get("ElementWiseSum");
static const Op* identity_op = Op::Get("identity");
static const Op* zeros_op = Op::Get("_zeros");
// remove zero in the sum.
static const Op* zeros_like_op = Op::Get("zeros_like");

if (v.size() == 0) {
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = zeros_op;
ng->attrs.name = "zeros";
ng->attrs.op->attr_parser(&(ng->attrs));
return nnvm::NodeEntry{ng, 0, 0};
}

// remove zero in the sum. at least keep 1.
size_t begin = 0;
for (size_t i = 0; i < v.size(); ++i) {
if (v[i].node->op() != zeros_op) {
if (v[i].node->op() != zeros_op && v[i].node->op() != zeros_like_op) {
if (begin != i) {
v[begin] = std::move(v[i]);
}
++begin;
}
}
if (begin == 0) begin = 1;
v.resize(begin);

if (v.size() == 1) {
return std::move(v[0]);
} else if (v.size() == 0) {
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = zeros_op;
ng->attrs.name = "zeros";
ng->attrs.op->attr_parser(&(ng->attrs));
return nnvm::NodeEntry{ng, 0, 0};
} else {
if (v.size() < inplace_sum_cap) {
nnvm::NodePtr sum_node = nnvm::Node::Create();
@@ -216,10 +221,16 @@ nnvm::Graph GraphExecutor::InitFullGraph(
if (type == "CuDNNBatchNorm") return false;
return true;
};

std::vector<const nnvm::Op*> zero_ops;
zero_ops.push_back(nnvm::Op::Get("zeros_like"));
zero_ops.push_back(nnvm::Op::Get("_zeros"));

// take gradient
nnvm::Graph g_grad = nnvm::pass::Gradient(
g, symbol.outputs, xs, head_grad_entry_,
AggregateGradient, need_mirror);
AggregateGradient, need_mirror, nullptr,
zero_ops);
CHECK_EQ(g_grad.outputs.size(), xs.size());
for (const auto &e : g_grad.outputs) {
g.outputs.push_back(e);
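The AggregateGradient change above treats both _zeros and zeros_like nodes as removable terms when summing head gradients, and keeps at least one term so the sum never becomes empty. A standalone sketch of the same compaction idea (illustrative only, not MXNet code; the real function operates on nnvm::NodeEntry values and a zero-op predicate):

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Drop "zero" terms from a gradient sum but keep at least one,
// mirroring the compaction loop in AggregateGradient.
std::vector<int> CompactGradSum(std::vector<int> v) {
  auto is_zero = [](int x) { return x == 0; };
  std::size_t begin = 0;
  for (std::size_t i = 0; i < v.size(); ++i) {
    if (!is_zero(v[i])) {
      if (begin != i) v[begin] = std::move(v[i]);
      ++begin;
    }
  }
  if (begin == 0) begin = 1;  // every term was zero: keep one zero entry
  v.resize(begin);
  return v;
}

int main() {
  assert((CompactGradSum({0, 3, 0, 5}) == std::vector<int>{3, 5}));
  assert((CompactGradSum({0, 0}) == std::vector<int>{0}));  // at least one kept
}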
30 changes: 28 additions & 2 deletions src/initialize.cc
@@ -3,17 +3,39 @@
* \file initialize.cc
* \brief initialize mxnet library
*/
#include <signal.h>
#include <dmlc/logging.h>
#include <mxnet/engine.h>

#include "engine/profiler.h"

namespace mxnet {

void segfault_logger(int sig) {
const int MAX_STACK_SIZE = 10;
void *stack[MAX_STACK_SIZE];

fprintf(stderr, "\nSegmentation fault: %d\n\n", sig);

#if DMLC_LOG_STACK_TRACE
int nframes = backtrace(stack, MAX_STACK_SIZE);
fprintf(stderr, "Stack trace returned %d entries:\n", nframes);
char **msgs = backtrace_symbols(stack, nframes);
if (msgs != nullptr) {
for (int i = 0; i < nframes; ++i) {
fprintf(stderr, "[bt] (%d) %s\n", i, msgs[i]);
}
}
#endif // DMLC_LOG_STACK_TRACE

exit(1);
}

class LibraryInitializer {
public:
LibraryInitializer() {
dmlc::InitLogging("mxnet");
signal(SIGSEGV, segfault_logger);
#if MXNET_USE_PROFILER
// ensure profiler's constructor are called before atexit.
engine::Profiler::Get();
@@ -26,8 +48,12 @@ class LibraryInitializer {
});
#endif
}

static LibraryInitializer* Get();
};

static LibraryInitializer __library_init;
LibraryInitializer* LibraryInitializer::Get() {
static LibraryInitializer inst;
return &inst;
}
} // namespace mxnet
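The new segfault_logger installs a SIGSEGV handler that prints a stack trace before exiting (the trace itself is guarded by DMLC_LOG_STACK_TRACE). A minimal standalone sketch of the same pattern, assuming a glibc/POSIX system with <execinfo.h>; it uses backtrace_symbols_fd rather than backtrace_symbols so no allocation happens inside the handler:

#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <execinfo.h>  // backtrace, backtrace_symbols_fd (glibc/POSIX only)

void segfault_handler(int sig) {
  void *stack[10];
  std::fprintf(stderr, "\nSegmentation fault: %d\n", sig);
  int nframes = backtrace(stack, 10);
  backtrace_symbols_fd(stack, nframes, 2);  // write symbolised frames to stderr
  std::_Exit(1);
}

int main() {
  std::signal(SIGSEGV, segfault_handler);
  int *p = nullptr;
  return *p;  // deliberate null dereference to demonstrate the handler
}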

12 changes: 4 additions & 8 deletions src/operator/elemwise_op_common.h
@@ -74,11 +74,7 @@ struct ElemwiseGradUseIn {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
for (auto& h : n->inputs) {
heads.push_back(h);
}
return MakeGradNode(op_name, n, heads, n->attrs.dict);
return MakeNonlossGradNode(op_name, n, ograds, n->inputs, n->attrs.dict);
}
};

@@ -87,12 +83,12 @@ struct ElemwiseGradUseOut {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
std::vector<nnvm::NodeEntry> heads;
index_t n_out = n->num_outputs();
for (index_t i = 0; i < n_out; ++i) {
heads.emplace_back(nnvm::NodeEntry{n, i, 0});
}
return MakeGradNode(op_name, n, heads, n->attrs.dict);
return MakeNonlossGradNode(op_name, n, ograds, heads, n->attrs.dict);
}
};

@@ -101,7 +97,7 @@ struct ElemwiseGradUseNone {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
return MakeGradNode(op_name, n, ograds, n->attrs.dict);
return MakeNonlossGradNode(op_name, n, ograds, {}, n->attrs.dict);
}
};

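ElemwiseGradUseIn, ElemwiseGradUseOut, and ElemwiseGradUseNone now differ only in which forward tensors they append after the output gradients (the inputs, the outputs, or nothing), and all three go through MakeNonlossGradNode so that all-zero output gradients short-circuit to zero input gradients. Roughly how such a functor is attached to an operator (a hedged, hypothetical fragment, not part of this diff; the op names are made up and the usual MXNet registration headers are assumed):

NNVM_REGISTER_OP(_hypothetical_sq)
.set_num_inputs(1)
.set_num_outputs(1)
// the backward op receives ograd followed by the forward input
.set_attr<nnvm::FGradient>("FGradient",
                           ElemwiseGradUseIn{"_backward_hypothetical_sq"});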
81 changes: 65 additions & 16 deletions src/operator/operator_common.h
@@ -171,21 +171,42 @@ inline bool type_assign(int *y, const int& x) {
#endif


// quick helper to make node
inline std::vector<nnvm::NodeEntry> MakeGradNode(
const char* op_name,
const nnvm::NodePtr& n,
std::vector<nnvm::NodeEntry> inputs,
std::unordered_map<std::string, std::string> dict) {
nnvm::NodePtr p = nnvm::Node::Create();
// make a new node with operator op_name. Inputs are not filled.
inline nnvm::NodePtr MakeNode(
const char* op_name, const std::string& name,
std::vector<nnvm::NodeEntry> const * inputs,
std::unordered_map<std::string, std::string> const * dict,
nnvm::NodePtr const * fwd_node) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get(op_name);
p->attrs.name = n->attrs.name + "_backward";
p->attrs.dict = std::move(dict);
p->attrs.name = name;
if (dict != nullptr) p->attrs.dict = *dict;
if (inputs != nullptr) p->inputs = *inputs;
if (fwd_node != nullptr) {
p->control_deps.emplace_back(*fwd_node);
}
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
p->control_deps.emplace_back(n);
p->inputs = std::move(inputs);
return p;
}

inline nnvm::NodePtr MakeNode(
const char* op_name, const std::string& name,
const std::vector<nnvm::NodeEntry>& inputs,
std::unordered_map<std::string, std::string> const * dict,
nnvm::NodePtr const * fwd_node) {
return MakeNode(op_name, name, &inputs, dict, fwd_node);
}


// quick helper to make node
inline std::vector<nnvm::NodeEntry> MakeGradNode(
const char* op_name, const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& inputs,
const std::unordered_map<std::string, std::string>& dict) {
auto p = MakeNode(op_name, n->attrs.name + "_backward",
&inputs, &dict, &n);
std::vector<nnvm::NodeEntry> ret;
for (index_t i = 0; i < p->num_outputs(); ++i) {
ret.emplace_back(nnvm::NodeEntry{p, i, 0});
@@ -199,22 +220,50 @@ inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> ret;
for (index_t i = 0; i < n->num_inputs(); ++i) {
nnvm::NodePtr p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_zeros");
std::ostringstream os;
if (1 == n->num_inputs()) {
os << n->attrs.name << "_backward";
} else {
os << n->attrs.name << "_in" << i << "_backward";
}
p->attrs.name = os.str();
p->attrs.dict = std::unordered_map<std::string, std::string>();
p->control_deps.emplace_back(n);
auto p = MakeNode("zeros_like", os.str(), {n->inputs[i]}, nullptr, &n);
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
}
return ret;
}


// check whether all output grads are zero.
inline bool CheckGradAllZero(const std::vector<nnvm::NodeEntry>& ograds) {
const auto zero_op = nnvm::Op::Get("_zeros");
const auto zero_like_op = nnvm::Op::Get("zeros_like");
if (!ograds.size()) return false;
for (const auto& grad : ograds) {
if (!grad.node) return false;
if (grad.node->op() != zero_op && grad.node->op() != zero_like_op ) return false;
}
return true;
}

// make gradient node that doesn't add to objective.
// i.e. igrads are always zero when ograds are zero.
inline std::vector<nnvm::NodeEntry> MakeNonlossGradNode(
const char* op_name, const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds,
const std::vector<nnvm::NodeEntry>& inputs,
const std::unordered_map<std::string, std::string> dict) {
if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
auto p = MakeNode(op_name, n->attrs.name + "_backward",
nullptr, &dict, &n);
p->inputs.insert(p->inputs.end(), ograds.begin(), ograds.end());
p->inputs.insert(p->inputs.end(), inputs.begin(), inputs.end());
std::vector<nnvm::NodeEntry> ret;
for (index_t i = 0; i < p->num_outputs(); ++i) {
ret.emplace_back(nnvm::NodeEntry{p, i, 0});
}
return ret;
}

/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
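CheckGradAllZero and MakeNonlossGradNode implement the "zero prop" mentioned in the commit message: when every incoming output gradient is a known zero node (_zeros or zeros_like), the real backward node is never built and zeros_like nodes are returned for the inputs instead. A standalone sketch of that short-circuit (illustrative only; plain structs stand in for nnvm node entries):

#include <iostream>
#include <string>
#include <vector>

// Stand-in for a gradient entry: either a known-zero node or a real expression.
struct Grad { bool is_zero; std::string expr; };

// Backward of y = a * b, short-circuiting when all output grads are zero,
// the same way MakeNonlossGradNode falls back to MakeZeroGradNodes.
std::vector<Grad> BackwardMul(const std::vector<Grad>& ograds) {
  bool all_zero = !ograds.empty();
  for (const auto& g : ograds) all_zero = all_zero && g.is_zero;
  if (all_zero)  // no backward node is built at all
    return {{true, "zeros_like(a)"}, {true, "zeros_like(b)"}};
  return {{false, ograds[0].expr + " * b"}, {false, ograds[0].expr + " * a"}};
}

int main() {
  for (const auto& g : BackwardMul({{true, "zeros_like(y)"}}))
    std::cout << g.expr << "\n";  // prints zeros_like(a), zeros_like(b)
}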
8 changes: 4 additions & 4 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -458,9 +458,9 @@ struct ReduceGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
return MakeGradNode(
return MakeNonlossGradNode(
op_name, n,
{ograds[0], n->inputs[0], nnvm::NodeEntry{n, 0, 0}},
ograds, {n->inputs[0], nnvm::NodeEntry{n, 0, 0}},
n->attrs.dict);
}
};
@@ -646,8 +646,8 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
.set_attr<nnvm::FGradient>("FGradient", \
[](const nnvm::NodePtr& n, \
const std::vector<nnvm::NodeEntry>& ograds) { \
return MakeGradNode("_broadcast_backward", n, ograds, \
{{"keepdims", "true"}}); \
return MakeNonlossGradNode("_broadcast_backward", n, ograds, {}, \
{{"keepdims", "true"}}); \
}) \
.add_argument("data", "ndarray-or-symbol", "The input")

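ReduceGrad and the broadcast macro above also switch to MakeNonlossGradNode, passing the output gradients first and then the forward input and output needed for shape information. For context, the backward of a sum-reduction is just the output gradient broadcast back over the reduced axis; a standalone sketch (illustrative only, not MXNet code):

#include <cstddef>
#include <vector>

// Backward of y[c] = sum over r of x[r][c]: every row of the input gradient
// receives the same output gradient, i.e. ograd broadcast along the reduced axis.
std::vector<double> SumAxis0Backward(const std::vector<double>& ograd,
                                     std::size_t rows, std::size_t cols) {
  std::vector<double> igrad(rows * cols);
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      igrad[r * cols + c] = ograd[c];
  return igrad;
}

int main() {
  auto g = SumAxis0Backward({1.0, 2.0}, 3, 2);
  return (g[0] == 1.0 && g[5] == 2.0) ? 0 : 1;  // igrad is {1,2, 1,2, 1,2}
}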
13 changes: 4 additions & 9 deletions src/operator/tensor/broadcast_reduce_op_index.cc
@@ -45,16 +45,11 @@ NNVM_REGISTER_OP(pick)
.set_attr<FCompute>("FCompute<cpu>", PickOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
heads.push_back(n->inputs[1]);
auto ret = MakeGradNode("_backward_pick", n, heads, n->attrs.dict);

nnvm::NodePtr p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_zeros");
p->attrs.name = n->attrs.name + "_index_backward";
p->control_deps.emplace_back(n);
auto ret = MakeNonlossGradNode("_backward_pick", n, ograds,
{n->inputs[1]}, n->attrs.dict);
auto p = MakeNode("zeros_like", n->attrs.name + "_index_backward",
{n->inputs[1]}, nullptr, &n);
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});

return ret;
})
.add_argument("data", "NDArray", "Source input")
9 changes: 2 additions & 7 deletions src/operator/tensor/control_flow_op.cc
@@ -35,13 +35,8 @@ NNVM_REGISTER_OP(where)
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> ret;
// make zero grad node for grad[condition]
nnvm::NodePtr p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_zeros");
std::ostringstream os;
os << n->attrs.name << "_in" << 0 << "_backward";
p->attrs.name = os.str();
p->attrs.dict = std::unordered_map<std::string, std::string>();
p->control_deps.emplace_back(n);
auto p = MakeNode("zeros_like", n->attrs.name + "_cond_backward",
{n->inputs[0]}, nullptr, &n);
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});

// make grad nodes for grad[x] and grad[y]
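The where operator now builds the zero gradient for its non-differentiable condition input with the MakeNode helper (a zeros_like of the condition) instead of a hand-rolled _zeros node. For context, the full gradient rule is: the condition gets zeros, and the output gradient is routed to x or y according to the mask; a standalone sketch (illustrative only):

#include <cstddef>
#include <vector>

struct WhereGrads {
  std::vector<double> dcond, dx, dy;
};

// Backward of out = where(cond, x, y): dcond is zeros_like(cond),
// dx takes ograd where cond is true, dy takes it where cond is false.
WhereGrads WhereBackward(const std::vector<double>& ograd,
                         const std::vector<bool>& cond) {
  WhereGrads g{std::vector<double>(ograd.size(), 0.0),
               std::vector<double>(ograd.size(), 0.0),
               std::vector<double>(ograd.size(), 0.0)};
  for (std::size_t i = 0; i < ograd.size(); ++i)
    (cond[i] ? g.dx[i] : g.dy[i]) = ograd[i];
  return g;
}

int main() {
  auto g = WhereBackward({1.0, 2.0}, {true, false});
  return (g.dx[0] == 1.0 && g.dy[1] == 2.0 && g.dcond[0] == 0.0) ? 0 : 1;
}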
21 changes: 9 additions & 12 deletions src/operator/tensor/elemwise_unary_op.cc
@@ -41,13 +41,8 @@ MXNET_OPERATOR_REGISTER_UNARY(make_loss)
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
nnvm::NodePtr p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_ones");
p->attrs.name = n->attrs.name + "_backward";
p->control_deps.emplace_back(n);
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
auto p = MakeNode("ones_like", n->attrs.name + "_backward",
&(n->inputs), nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
return ret;
@@ -60,16 +55,18 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
"FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
[](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); })
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FGradient>(
"FGradient", [](const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
auto lhs = MakeGradNode("_backward_copy", n, ograds,
std::unordered_map<std::string, std::string>());
nnvm::NodePtr ng = nnvm::Node::Create();
ng->attrs.op = nnvm::Op::Get("_zeros");
ng->attrs.name = "zeros";
auto lhs = MakeNonlossGradNode(
"_backward_copy", n, ograds, {},
std::unordered_map<std::string, std::string>());
auto ng = MakeNode("zeros_like", n->attrs.name + "rhs_backward",
{n->inputs[1]}, nullptr, &n);
lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
return lhs;
});
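Two changes above: make_loss now emits its gradient as ones_like of the input (instead of a shape-parsed _ones node), and _identity_with_attr_like_rhs now returns a proper zeros_like gradient for its rhs argument while marking that input as ignored during execution. For context on the first one, an output treated as a loss contributes dL/dx[i] = 1 for every element; a tiny standalone sketch (illustrative only):

#include <cassert>
#include <vector>

// If L = sum(x) is the objective, then dL/dx[i] == 1 for every i,
// which is exactly ones_like(x).
std::vector<double> MakeLossGrad(const std::vector<double>& x) {
  return std::vector<double>(x.size(), 1.0);
}

int main() {
  assert((MakeLossGrad({0.5, -2.0, 3.0}) == std::vector<double>(3, 1.0)));
}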
10 changes: 4 additions & 6 deletions src/operator/tensor/indexing_op.cc
@@ -33,9 +33,8 @@ NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
heads.push_back(n->inputs[0]);
return MakeGradNode("_backward_Embedding", n, heads, n->attrs.dict);
return MakeNonlossGradNode("_backward_Embedding", n, ograds,
{n->inputs[0]}, n->attrs.dict);
})
.add_argument("data", "Symbol", "Input data to the EmbeddingOp.")
.add_argument("weight", "Symbol", "Embedding weight matrix.")
@@ -93,9 +92,8 @@ Examples::
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
heads.push_back(n->inputs[1]);
return MakeGradNode("_backward_take", n, heads, n->attrs.dict);
return MakeNonlossGradNode("_backward_take", n, ograds,
{n->inputs[1]}, n->attrs.dict);
})
.add_argument("a", "ndarray-or-symbol", "The source array.")
.add_argument("indices", "ndarray-or-symbol", "The indices of the values to extract.")
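Embedding and take now build their backward nodes with MakeNonlossGradNode, passing the output gradient plus only the index input: the backward of a gather is a scatter-add of ograd into the rows named by the indices, so the gathered data values themselves are never needed. A standalone sketch of that scatter-add (illustrative only, 1-D case):

#include <cstddef>
#include <iostream>
#include <vector>

// Backward of out[i] = data[indices[i]]: accumulate each output gradient
// into the source position it was gathered from (a scatter-add).
std::vector<double> TakeBackward(const std::vector<double>& ograd,
                                 const std::vector<std::size_t>& indices,
                                 std::size_t data_size) {
  std::vector<double> igrad(data_size, 0.0);
  for (std::size_t i = 0; i < indices.size(); ++i)
    igrad[indices[i]] += ograd[i];
  return igrad;
}

int main() {
  auto g = TakeBackward({1.0, 2.0, 3.0}, {0, 2, 0}, 4);
  for (double v : g) std::cout << v << " ";  // prints: 4 0 2 0
  std::cout << "\n";
}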