
Merge pull request #13 from mavenlin/master
out_data is necessary, e.g. sigmoid
antinucleon committed Aug 14, 2015
2 parents b5f75f1 + 4dd1924 commit 0cf889c
Showing 10 changed files with 20 additions and 4 deletions.
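
Why out_data is needed: as the commit message says, some operators compute their input gradient from the forward output rather than from the forward input. Sigmoid is the canonical case, since d/dx sigmoid(x) = out * (1 - out). Below is a minimal, self-contained C++ sketch of that pattern; it is illustrative only, uses plain std::vector instead of MXNet's TBlob/NArray, and the function names are made up for this example.

#include <cmath>
#include <cstddef>
#include <vector>

// Forward pass: out[i] = sigmoid(in[i]).
std::vector<float> SigmoidForward(const std::vector<float> &in) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = 1.0f / (1.0f + std::exp(-in[i]));
  }
  return out;
}

// Backward pass: the sigmoid gradient is out * (1 - out), so it can be
// computed from the forward output alone; in_data is not needed at all.
std::vector<float> SigmoidBackward(const std::vector<float> &grad_next,
                                   const std::vector<float> &out_data) {
  std::vector<float> in_grad(out_data.size());
  for (std::size_t i = 0; i < out_data.size(); ++i) {
    in_grad[i] = grad_next[i] * out_data[i] * (1.0f - out_data[i]);
  }
  return in_grad;
}

Before this change, Backward only received grad_next and in_data, so an operator of this kind would have had to recompute or stash its own forward output; threading out_data through the interface makes the output available directly.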
2 changes: 2 additions & 0 deletions include/mxnet/c_api.h
@@ -385,6 +385,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op,
  * \param op the operator handle
  * \param grad_next array of output gradients
  * \param in_data array of input narray to the operator
+ * \param out_data array of output narray to the operator
  * \param out_grad array to holds the gradient on these input
  *    can be NULL if that position request is kNullOp
  * \param reqs gradient request type
@@ -394,6 +395,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op,
 MXNET_DLL int MXOpBackward(OperatorHandle op,
                            NArrayHandle *grad_next,
                            NArrayHandle *in_data,
+                           NArrayHandle *out_data,
                            NArrayHandle *out_grad,
                            mx_uint *reqs);

4 changes: 4 additions & 0 deletions include/mxnet/operator.h
@@ -53,6 +53,7 @@ class StaticOperator {
    * \param ctx runtime context
    * \param grad_next the gradient value we get from output of the StaticOperator
    * \param in_data the array of input data
+   * \param out_data the array of output data
    * \param out_grad array of output gradient, there could be three possible TBlob
    *  in the each element in the array
    * \param req request types of the gradient saving operation
@@ -62,6 +63,7 @@ class StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) = 0;
   /*!
@@ -113,6 +115,7 @@ class Operator {
    * \param ctx runtime context
    * \param grad_next the gradient value of the output of the operator, used by chain rule.
    * \param in_data the array of input data
+   * \param out_data the array of output data
    * \param out_grad array of output gradient
    * \param req request types of the gradient saving operation
    *  only inplace will change input data
@@ -121,6 +124,7 @@ class Operator {
   virtual void Backward(RunContext ctx,
                         const std::vector<NArray> &grad_next,
                         const std::vector<NArray> &in_data,
+                        const std::vector<NArray> &out_data,
                         const std::vector<NArray> &out_grad,
                         const std::vector<GradReqType> &req) = 0;
   /*!
2 changes: 2 additions & 0 deletions src/operator/composite_operator.h
@@ -67,6 +67,7 @@ class CompositeOperator : public Operator {
    * \param ctx runtime context
    * \param grad_next the gradient value of the output of the operator, used by chain rule.
    * \param in_data the array of input data
+   * \param out_data the array of output data
    * \param out_grad array of output gradient
    * \param req request types of the gradient saving operation
    *  only inplace will change input data
@@ -75,6 +76,7 @@ class CompositeOperator : public Operator {
   virtual void Backward(RunContext ctx,
                         const std::vector<NArray> &grad_next,
                         const std::vector<NArray> &in_data,
+                        const std::vector<NArray> &out_data,
                         const std::vector<NArray> &out_grad,
                         const std::vector<GradReqType> &req);
   /*!
1 change: 1 addition & 0 deletions src/operator/static_operator/activation_op-inl.h
@@ -39,6 +39,7 @@ class ActivationOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
1 change: 1 addition & 0 deletions src/operator/static_operator/convolution_op-inl.h
@@ -134,6 +134,7 @@ class ConvolutionOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     using namespace mshadow;
1 change: 1 addition & 0 deletions src/operator/static_operator/dropout_op-inl.h
@@ -59,6 +59,7 @@ class DropoutOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
1 change: 1 addition & 0 deletions src/operator/static_operator/fully_connect_op-inl.h
@@ -56,6 +56,7 @@ class FullyConnectOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     using namespace mshadow;
1 change: 1 addition & 0 deletions src/operator/static_operator/pooling_op-inl.h
@@ -88,6 +88,7 @@ class PoolingOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
1 change: 1 addition & 0 deletions src/operator/static_operator/reshape_op-inl.h
@@ -52,6 +52,7 @@ class ReshapeOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
10 changes: 6 additions & 4 deletions src/operator/static_operator_wrapper.cc
@@ -55,27 +55,29 @@ class StaticOperatorWrapper: public Operator {
   virtual void Backward(RunContext ctx,
                         const std::vector<NArray> &grad_next,
                         const std::vector<NArray> &in_data,
+                        const std::vector<NArray> &out_data,
                         const std::vector<NArray> &out_grad,
                         const std::vector<GradReqType> &req) {
     std::vector<DAGEngine::Variable> used_var;
     std::vector<DAGEngine::Variable> mutate_var;
     std::vector<TBlob> grad_in;
     std::vector<TBlob> grad_out;
-    std::vector<TBlob> data;
+    std::vector<TBlob> data_in;
+    std::vector<TBlob> data_out;
     for (size_t i = 0; i < grad_next.size(); ++i) {
       used_var.push_back(grad_next[i].var());
       grad_in.push_back(grad_next[i].data());
     }
     for (size_t i = 0; i < in_data.size(); ++i) {
       used_var.push_back(in_data[i].var());
-      data.push_back(in_data[i].data());
+      data_in.push_back(in_data[i].data());
     }
     for (size_t i = 0; i < out_grad.size(); ++i) {
       mutate_var.push_back(out_grad[i].var());
       grad_out.push_back(out_grad[i].data());
     }
-    DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) {
-      op_->Backward(ctx, grad_in, data, grad_out, req);
+    DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data_in, data_out, req](RunContext ctx) {
+      op_->Backward(ctx, grad_in, data_in, data_out, grad_out, req);
     }, ctx_, used_var, mutate_var);
   }

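For context on the wrapper change above: StaticOperatorWrapper::Backward collects read dependencies (used_var) from the arrays it reads, write dependencies (mutate_var) from the gradients it produces, captures the underlying blobs by value, and pushes the real StaticOperator::Backward call onto the DAG engine. The following is a rough toy sketch of that enqueue pattern only; it assumes nothing about MXNet's actual DAGEngine, NArray, or TBlob APIs beyond what the hunk shows, and every type and name in it is a stand-in.

#include <cstddef>
#include <functional>
#include <vector>

// Toy stand-ins for illustration; MXNet's real types are richer than this.
using Variable = int;

struct ToyEngine {
  // Run fn once its read/write dependencies are satisfied. A real engine
  // schedules asynchronously; this toy version just runs the task in place.
  void Push(std::function<void()> fn,
            const std::vector<Variable> &read_vars,
            const std::vector<Variable> &write_vars) {
    (void)read_vars;
    (void)write_vars;
    fn();
  }
};

struct ToyArray {
  Variable var;             // dependency-tracking handle, like NArray::var()
  std::vector<float> data;  // payload, standing in for the TBlob from data()
};

// Mirrors the shape of the wrapper's Backward: gather dependencies, copy out
// the payloads, then enqueue the actual computation as a captured closure.
void BackwardLike(ToyEngine *engine,
                  const std::vector<ToyArray> &grad_next,
                  const std::vector<ToyArray> &out_data,
                  std::vector<ToyArray> *out_grad) {
  std::vector<Variable> used_var;
  std::vector<Variable> mutate_var;
  std::vector<std::vector<float>> grad_in;
  std::vector<std::vector<float>> data_out;
  for (std::size_t i = 0; i < grad_next.size(); ++i) {
    used_var.push_back(grad_next[i].var);
    grad_in.push_back(grad_next[i].data);
  }
  for (std::size_t i = 0; i < out_data.size(); ++i) {
    used_var.push_back(out_data[i].var);
    data_out.push_back(out_data[i].data);
  }
  for (std::size_t i = 0; i < out_grad->size(); ++i) {
    mutate_var.push_back((*out_grad)[i].var);
  }
  // Capture the blobs by value so the task owns its inputs when it runs; the
  // real wrapper calls op_->Backward(...) inside this closure. Here a
  // sigmoid-style gradient serves as a stand-in payload.
  engine->Push([grad_in, data_out, out_grad]() {
    for (std::size_t i = 0; i < out_grad->size(); ++i) {
      for (std::size_t j = 0; j < (*out_grad)[i].data.size(); ++j) {
        (*out_grad)[i].data[j] =
            grad_in[i][j] * data_out[i][j] * (1.0f - data_out[i][j]);
      }
    }
  }, used_var, mutate_var);
}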
