This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit b1127a7 (merge)
antinucleon committed Aug 14, 2015
2 parents: 270bfb2 + 0cf889c
Showing 6 changed files with 18 additions and 4 deletions.
8 changes: 8 additions & 0 deletions include/mxnet/operator.h

@@ -84,6 +84,14 @@ class StaticOperator {
                         const std::vector<TBlob> &in_grad) = 0;
+  virtual void Backward(RunContext ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad) = 0;
 };
 
 #if DMLC_USE_CXX11
 /*!
  * \brief Operator interface.
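The common thread across all six files is that Backward now also receives the operator's forward outputs (out_data), so gradient code can reuse values already computed in the forward pass instead of recomputing them. As a minimal sketch of what a subclass implementing the widened contract looks like (MyOp is illustrative only, not part of this commit; other pure-virtual members such as Forward are omitted):

    // Hypothetical subclass of the interface above; "MyOp" is not in the commit.
    class MyOp : public StaticOperator {
     public:
      virtual void Backward(RunContext ctx,
                            const std::vector<TBlob> &out_grad,
                            const std::vector<TBlob> &in_data,
                            const std::vector<TBlob> &out_data,  // new: forward outputs
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &in_grad) {
        // Gradients that depend on the forward outputs (activations are the
        // classic case) can read them from out_data rather than recompute
        // them from in_data.
      }
    };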
1 change: 1 addition & 0 deletions src/operator/static_operator/activation_op-inl.h

@@ -39,6 +39,7 @@ class ActivationOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
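Activation gradients are the textbook motivation for passing out_data: for sigmoid, y = 1/(1+exp(-x)) and dy/dx = y(1-y), so the backward pass needs only the forward output. A self-contained scalar illustration (plain C++, not the mxnet kernel):

    // Why out_data suffices for an activation's backward pass:
    // for sigmoid, the input gradient is grad_y * y * (1 - y).
    #include <cmath>
    #include <cstdio>

    int main() {
      float x = 0.5f;
      float y = 1.0f / (1.0f + std::exp(-x));   // forward output (out_data)
      float grad_y = 1.0f;                      // incoming gradient (grad_next)
      float grad_x = grad_y * y * (1.0f - y);   // backward uses only y, not x
      std::printf("y=%f grad_x=%f\n", y, grad_x);
      return 0;
    }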
1 change: 1 addition & 0 deletions src/operator/static_operator/convolution_op-inl.h

@@ -134,6 +134,7 @@ class ConvolutionOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     using namespace mshadow;
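Convolution, by contrast, does not actually need out_data: its input gradient is a correlation of out_grad with the kernel, and its weight gradient uses in_data. The parameter is threaded through here anyway so that every StaticOperator shares one signature. A toy 1-D input-gradient computation (plain C++, not the mxnet kernel):

    // 1-D "valid" convolution (cross-correlation) with a length-2 kernel:
    //   forward:  out[j] = sum_k in[j+k] * kernel[k]
    //   backward: in_grad[j+k] += out_grad[j] * kernel[k]
    #include <cstdio>

    int main() {
      const float kernel[2]   = {0.5f, -1.0f};
      const float out_grad[2] = {1.0f, 1.0f};   // incoming gradients
      float in_grad[3] = {0.0f, 0.0f, 0.0f};    // gradient w.r.t. the input
      for (int j = 0; j < 2; ++j)
        for (int k = 0; k < 2; ++k)
          in_grad[j + k] += out_grad[j] * kernel[k];
      for (int i = 0; i < 3; ++i) std::printf("%f ", in_grad[i]);
      std::printf("\n");                        // prints 0.5 -0.5 -1.0
      return 0;
    }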
1 change: 1 addition & 0 deletions src/operator/static_operator/pooling_op-inl.h

@@ -88,6 +88,7 @@ class PoolingOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
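Max pooling is an operator that genuinely benefits: comparing in_data against out_data recovers which element won each window, which is where the incoming gradient must be routed. A toy 1-D version (plain C++, not the mxnet kernel):

    // Backward of 1-D max pooling, window 2, stride 2: the gradient for a
    // window goes to the element that equals the forward output (the max).
    #include <cstdio>

    int main() {
      const float in_data[4]  = {1.f, 3.f, 2.f, 0.f};
      const float out_data[2] = {3.f, 2.f};       // forward max per window
      const float out_grad[2] = {0.5f, 0.25f};    // incoming gradients
      float in_grad[4] = {0.f, 0.f, 0.f, 0.f};
      for (int w = 0; w < 2; ++w) {               // one window at a time
        for (int i = 2 * w; i < 2 * w + 2; ++i) {
          if (in_data[i] == out_data[w]) {        // this element was the max
            in_grad[i] += out_grad[w];
            break;
          }
        }
      }
      for (int i = 0; i < 4; ++i) std::printf("%f ", in_grad[i]);
      std::printf("\n");                          // prints 0 0.5 0.25 0
      return 0;
    }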
1 change: 1 addition & 0 deletions src/operator/static_operator/reshape_op-inl.h

@@ -52,6 +52,7 @@ class ReshapeOp : public StaticOperator {
   virtual void Backward(RunContext ctx,
                         const std::vector<TBlob> &grad_next,
                         const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<GradReqType> &req) {
     CHECK_EQ(grad_next.size(), 1);
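Reshape's backward pass is pure shape bookkeeping: the gradient buffer is copied through unchanged and merely viewed under the input's shape, so neither in_data nor out_data values are consulted. In miniature (plain C++, not the mxnet kernel):

    // Reshape backward: element-wise copy, same buffer order, new shape.
    #include <algorithm>
    #include <cstdio>

    int main() {
      const float grad_next[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // shape (2,3)
      float in_grad[6];                                           // shape (3,2)
      std::copy(grad_next, grad_next + 6, in_grad);  // gradient flows through
      for (int i = 0; i < 6; ++i) std::printf("%f ", in_grad[i]);
      std::printf("\n");
      return 0;
    }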
10 changes: 6 additions & 4 deletions src/operator/static_operator_wrapper.cc

@@ -55,27 +55,29 @@ class StaticOperatorWrapper: public Operator {
   virtual void Backward(RunContext ctx,
                         const std::vector<NArray> &grad_next,
                         const std::vector<NArray> &in_data,
+                        const std::vector<NArray> &out_data,
                         const std::vector<NArray> &out_grad,
                         const std::vector<GradReqType> &req) {
     std::vector<DAGEngine::Variable> used_var;
     std::vector<DAGEngine::Variable> mutate_var;
     std::vector<TBlob> grad_in;
     std::vector<TBlob> grad_out;
-    std::vector<TBlob> data;
+    std::vector<TBlob> data_in;
+    std::vector<TBlob> data_out;
     for (size_t i = 0; i < grad_next.size(); ++i) {
       used_var.push_back(grad_next[i].var());
       grad_in.push_back(grad_next[i].data());
     }
     for (size_t i = 0; i < in_data.size(); ++i) {
       used_var.push_back(in_data[i].var());
-      data.push_back(in_data[i].data());
+      data_in.push_back(in_data[i].data());
     }
     for (size_t i = 0; i < out_grad.size(); ++i) {
       mutate_var.push_back(out_grad[i].var());
       grad_out.push_back(out_grad[i].data());
     }
-    DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) {
-      op_->Backward(ctx, grad_in, data, grad_out, req);
+    DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data_in, data_out, req](RunContext ctx) {
+      op_->Backward(ctx, grad_in, data_in, data_out, grad_out, req);
     }, ctx_, used_var, mutate_var);
   }
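One detail worth noting in the hunk above: data_out is declared and captured by the lambda, but nothing copies the new out_data argument into it, so the wrapped StaticOperator receives an empty out_data vector. A loop mirroring the in_data one would presumably fill it (an assumption on our part, not something this commit adds):

    // NOT in this commit: the hunk above never populates data_out.
    // By analogy with the in_data loop, the fill would presumably be:
    for (size_t i = 0; i < out_data.size(); ++i) {
      used_var.push_back(out_data[i].var());
      data_out.push_back(out_data[i].data());
    }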
