From 4dd19249dde108d9ab2989ff94c331fef6e041c7 Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 12 Aug 2015 21:44:34 +0800 Subject: [PATCH] out_data is necessary, e.g. sigmoid --- include/mxnet/c_api.h | 2 ++ include/mxnet/operator.h | 4 ++++ src/operator/composite_operator.h | 2 ++ src/operator/static_operator/activation_op-inl.h | 1 + src/operator/static_operator/convolution_op-inl.h | 1 + src/operator/static_operator/dropout_op-inl.h | 1 + src/operator/static_operator/fully_connect_op-inl.h | 1 + src/operator/static_operator/pooling_op-inl.h | 1 + src/operator/static_operator/reshape_op-inl.h | 1 + src/operator/static_operator_wrapper.cc | 10 ++++++---- 10 files changed, 20 insertions(+), 4 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 29c9691e8ff5..bb718b6f9fdb 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -385,6 +385,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op, * \param op the operator handle * \param grad_next array of output gradients * \param in_data array of input narray to the operator + * \param out_data array of output narray to the operator * \param out_grad array to holds the gradient on these input * can be NULL if that position request is kNullOp * \param reqs gradient request type @@ -394,6 +395,7 @@ MXNET_DLL int MXOpForward(OperatorHandle op, MXNET_DLL int MXOpBackward(OperatorHandle op, NArrayHandle *grad_next, NArrayHandle *in_data, + NArrayHandle *out_data, NArrayHandle *out_grad, mx_uint *reqs); diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index c1a53df61fa9..1a8f6fac97a1 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -53,6 +53,7 @@ class StaticOperator { * \param ctx runtime context * \param grad_next the gradient value we get from output of the StaticOperator * \param in_data the array of input data + * \param out_data the array of output data * \param out_grad array of output gradient, there could be three possible TBlob * in the each 
element in the array * \param req request types of the gradient saving operation @@ -62,6 +63,7 @@ class StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) = 0; /*! @@ -113,6 +115,7 @@ class Operator { * \param ctx runtime context * \param grad_next the gradient value of the output of the operator, used by chain rule. * \param in_data the array of input data + * \param out_data the array of output data * \param out_grad array of output gradient * \param req request types of the gradient saving operation * only inplace will change input data @@ -121,6 +124,7 @@ class Operator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) = 0; /*! diff --git a/src/operator/composite_operator.h b/src/operator/composite_operator.h index 12297dc41c43..d80c32879422 100644 --- a/src/operator/composite_operator.h +++ b/src/operator/composite_operator.h @@ -67,6 +67,7 @@ class CompositeOperator : public Operator { * \param ctx runtime context * \param grad_next the gradient value of the output of the operator, used by chain rule. * \param in_data the array of input data + * \param out_data the array of output data * \param out_grad array of output gradient * \param req request types of the gradient saving operation * only inplace will change input data @@ -75,6 +76,7 @@ class CompositeOperator : public Operator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req); /*! 
diff --git a/src/operator/static_operator/activation_op-inl.h b/src/operator/static_operator/activation_op-inl.h index c888b35c9c61..cfb0b7cec8b5 100644 --- a/src/operator/static_operator/activation_op-inl.h +++ b/src/operator/static_operator/activation_op-inl.h @@ -39,6 +39,7 @@ class ActivationOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); diff --git a/src/operator/static_operator/convolution_op-inl.h b/src/operator/static_operator/convolution_op-inl.h index 2271839b697a..fc9b3369f2a6 100644 --- a/src/operator/static_operator/convolution_op-inl.h +++ b/src/operator/static_operator/convolution_op-inl.h @@ -134,6 +134,7 @@ class ConvolutionOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { using namespace mshadow; diff --git a/src/operator/static_operator/dropout_op-inl.h b/src/operator/static_operator/dropout_op-inl.h index b79a79fbea65..23c9f6aab457 100644 --- a/src/operator/static_operator/dropout_op-inl.h +++ b/src/operator/static_operator/dropout_op-inl.h @@ -59,6 +59,7 @@ class DropoutOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); diff --git a/src/operator/static_operator/fully_connect_op-inl.h b/src/operator/static_operator/fully_connect_op-inl.h index d39335deeeff..62e94f00ae36 100644 --- a/src/operator/static_operator/fully_connect_op-inl.h +++ b/src/operator/static_operator/fully_connect_op-inl.h @@ -56,6 +56,7 @@ class FullyConnectOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector 
&grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { using namespace mshadow; diff --git a/src/operator/static_operator/pooling_op-inl.h b/src/operator/static_operator/pooling_op-inl.h index db5e40ffb4a7..8c6014a8c2cf 100644 --- a/src/operator/static_operator/pooling_op-inl.h +++ b/src/operator/static_operator/pooling_op-inl.h @@ -88,6 +88,7 @@ class PoolingOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); diff --git a/src/operator/static_operator/reshape_op-inl.h b/src/operator/static_operator/reshape_op-inl.h index 44d8f8fcef24..ba966a62a29f 100644 --- a/src/operator/static_operator/reshape_op-inl.h +++ b/src/operator/static_operator/reshape_op-inl.h @@ -52,6 +52,7 @@ class ReshapeOp : public StaticOperator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { CHECK_EQ(grad_next.size(), 1); diff --git a/src/operator/static_operator_wrapper.cc b/src/operator/static_operator_wrapper.cc index afd4bae6241c..1690d067c6e6 100644 --- a/src/operator/static_operator_wrapper.cc +++ b/src/operator/static_operator_wrapper.cc @@ -55,27 +55,34 @@ class StaticOperatorWrapper: public Operator { virtual void Backward(RunContext ctx, const std::vector &grad_next, const std::vector &in_data, + const std::vector &out_data, const std::vector &out_grad, const std::vector &req) { std::vector used_var; std::vector mutate_var; std::vector grad_in; std::vector grad_out; - std::vector data; + std::vector data_in; + std::vector data_out; for (size_t i = 0; i < grad_next.size(); ++i) { used_var.push_back(grad_next[i].var()); grad_in.push_back(grad_next[i].data()); } for (size_t i = 0; i < in_data.size(); 
++i) { used_var.push_back(in_data[i].var()); - data.push_back(in_data[i].data()); + data_in.push_back(in_data[i].data()); } + /* populate data_out (otherwise passed empty to op_->Backward) and track the out_data vars in used_var, mirroring the in_data loop */ + for (size_t i = 0; i < out_data.size(); ++i) { + used_var.push_back(out_data[i].var()); + data_out.push_back(out_data[i].data()); + } for (size_t i = 0; i < out_grad.size(); ++i) { mutate_var.push_back(out_grad[i].var()); grad_out.push_back(out_grad[i].data()); } - DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data, req](RunContext ctx) { - op_->Backward(ctx, grad_in, data, grad_out, req); + DAGEngine::Get()->Push([this, ctx, grad_in, grad_out, data_in, data_out, req](RunContext ctx) { + op_->Backward(ctx, grad_in, data_in, data_out, grad_out, req); }, ctx_, used_var, mutate_var); }