This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] Fully connected, higher order grad #14779

Merged: 33 commits merged on Sep 30, 2019
Commits (33)
5bda12f
Add backward to fully connected. (_backward_FullyConnected)
larroy Apr 23, 2019
43babec
NodeEntry refactor fallout
larroy May 24, 2019
399d973
Add comment on correctness of 2nd order grad
larroy May 24, 2019
6289bfc
simplify implementation of second order gradient for FC
larroy May 28, 2019
e1d5989
Address CR comments
larroy Jun 18, 2019
a3df6fa
Address CR comments
larroy Jun 18, 2019
dc14f93
Python2 compat
larroy Jun 19, 2019
e3a32a8
Code review
larroy Jun 26, 2019
cff1628
Code review
larroy Jun 26, 2019
d58a225
sketch of 2nd order gradient for FC, shapes don't work
larroy Jul 3, 2019
26b5b20
Add its own operator
larroy Jul 17, 2019
94093c7
Save work
larroy Jul 17, 2019
96ee98e
implementation of FCGradGrad
larroy Jul 17, 2019
02499ff
fix enums
larroy Jul 19, 2019
e3e9858
refine FC higher order gradient test
larroy Jul 19, 2019
501918a
Lint and CR
larroy Jul 19, 2019
bcd14f3
CR
larroy Jul 19, 2019
8dc3efd
CR
larroy Jul 19, 2019
9d0ed3d
CR comments
larroy Jul 19, 2019
31eb047
Refactor camelcase
larroy Jul 20, 2019
04994fc
Fix typo
larroy Jul 22, 2019
57f71e1
Add bias 2nd order
larroy Jul 23, 2019
9a8de93
Fix FC flatten test
larroy Jul 23, 2019
8ee0b0d
Fix and test FC without flattened inputs
larroy Jul 23, 2019
983c1c4
Clean imports
larroy Jul 23, 2019
8165ef5
refine comment
larroy Jul 24, 2019
14f1863
CR
larroy Jul 30, 2019
256dfb2
CR
larroy Aug 20, 2019
664f7a1
merge with master
larroy Sep 25, 2019
99fa191
Fix grammar
larroy Sep 25, 2019
0f70b99
Build fix
larroy Sep 28, 2019
441f499
Merge remote-tracking branch 'upstream/master' into fc_higher_order_grad
larroy Sep 28, 2019
18e3226
Fix submodules
larroy Sep 28, 2019
2 changes: 1 addition & 1 deletion include/mxnet/tensor_blob.h
@@ -55,7 +55,7 @@ class NDArray;
* \brief tensor blob class that can be used to hold tensor of any dimension,
* any device and any data type,
* This is a weak type that can be used to transfer data through interface
* TBlob itself do not involve any arithmentic operations,
* TBlob itself doesn't involve any arithmetic operations,
* but it can be converted to tensor of fixed dimension for further operations
*
* Like tensor, this data structure is like a pointer class and do not
1 change: 1 addition & 0 deletions src/operator/linalg.h
@@ -54,6 +54,7 @@ using namespace mshadow;
// CPU/GPU-versions of BLAS3 function "gemm". Please refer to the BLAS3-documentation
// for further information about the function and its parameters.
// Note that this is C = gemm(A,B,C), so C is input and output parameter.
// C = alpha * A * B + beta * C
template<typename xpu, typename DType>
void linalg_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
const Tensor<xpu, 2, DType>& C, DType alpha, DType beta,
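
For readers new to the gemm convention documented above, here is a minimal reference sketch of that contract for row-major matrices. It only illustrates C = alpha * A * B + beta * C; the real linalg_gemm dispatches to BLAS and operates on mshadow tensors, so this is not the MXNet implementation.

#include <cstddef>
#include <vector>

// Reference-only sketch of C = alpha * A * B + beta * C for row-major matrices.
// A is m x k, B is k x n, C is m x n and serves as both input and output.
void gemm_reference(const std::vector<double>& A, const std::vector<double>& B,
                    std::vector<double>* C, std::size_t m, std::size_t k,
                    std::size_t n, double alpha, double beta) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      double acc = 0.0;
      for (std::size_t p = 0; p < k; ++p)
        acc += A[i * k + p] * B[p * n + j];
      (*C)[i * n + j] = alpha * acc + beta * (*C)[i * n + j];
    }
  }
}
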
205 changes: 167 additions & 38 deletions src/operator/nn/fully_connected-inl.h
@@ -32,6 +32,7 @@
#include <vector>
#include <string>
#include <utility>
#include <limits>
#include <algorithm>
#include "../operator_common.h"
#include "../elemwise_op_common.h"
@@ -48,7 +49,11 @@ namespace fullc {
enum FullyConnectedOpInputs {kData, kWeight, kBias};
enum FullyConnectedOpResource {kTempSpace};
enum FullyConnectedOpOutputs {kOut};
} // fullc
enum FullyConnectedGradGradOutputs { kOyGrad, kXGradGrad, kWGradGrad, kBGradGrad };
enum GradGradInputs { kOxGrad, kOwGrad, };
enum GradGradInputsBias { kObGrad = 2, kOyBias, };
enum GradGradInputsNoBias { kOy = 2, };
} // namespace fullc

namespace quantized_fullc {
enum QuantizedFCInputMinMax {kDataMin, kDataMax, kWeightMin, kWeightMax, kBiasMin, kBiasMax};
@@ -77,6 +82,38 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
}
};

/**
* Flatten additional dimensions after the first
* @tparam xpu
* @tparam DType
* @param tblob
* @param ctx
* @return 2-dimensional tensor with the trailing dimensions collapsed into the second axis
*/
template<typename xpu, typename DType>
Tensor<xpu, 2, DType> FlattenAs2DTail(const TBlob& tblob, const OpContext& ctx) {
const TShape& shape = tblob.shape_;
Stream<xpu> *stream = ctx.get_stream<xpu>();
return tblob.get_with_shape<xpu, 2, DType>(
Shape2(shape[0], shape.ProdShape(1, shape.ndim())), stream);
}

/**
* Flatten dimensions except last
* @tparam xpu
* @tparam DType
* @param tblob
* @param ctx
* @return 2-dimensional tensor with the leading dimensions collapsed into the first axis
*/
template<typename xpu, typename DType>
Tensor<xpu, 2, DType> FlattenAs2DHead(const TBlob& tblob, const OpContext& ctx) {
const TShape& shape = tblob.shape_;
Stream<xpu> *stream = ctx.get_stream<xpu>();
return tblob.get_with_shape<xpu, 2, DType>(
Shape2(shape.ProdShape(0, shape.ndim()-1), shape[shape.ndim()-1]), stream);
}
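
A concrete shape example: for a blob of shape (2, 3, 4), FlattenAs2DTail yields a (2, 12) tensor and FlattenAs2DHead yields a (6, 4) tensor. The shape arithmetic alone, as a stand-alone sketch (shapes modeled here as std::vector<std::size_t>; the real helpers operate on TShape and return mshadow tensors):

#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// (d0, d1, ..., dk) -> (d0, d1 * ... * dk), mirroring FlattenAs2DTail.
std::pair<std::size_t, std::size_t> flatten_tail(const std::vector<std::size_t>& shape) {
  const std::size_t tail = std::accumulate(shape.begin() + 1, shape.end(),
                                           std::size_t{1}, std::multiplies<std::size_t>());
  return {shape.front(), tail};
}

// (d0, ..., d(k-1), dk) -> (d0 * ... * d(k-1), dk), mirroring FlattenAs2DHead.
std::pair<std::size_t, std::size_t> flatten_head(const std::vector<std::size_t>& shape) {
  const std::size_t head = std::accumulate(shape.begin(), shape.end() - 1,
                                           std::size_t{1}, std::multiplies<std::size_t>());
  return {head, shape.back()};
}

// For {2, 3, 4}: flatten_tail -> (2, 12), flatten_head -> (6, 4).
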

template<typename DType>
void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,
Tensor<cpu, 2, DType> out, Stream<cpu>*) {
@@ -153,21 +190,14 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
<< "Must init CuBLAS handle in stream";
#endif // __CUDACC__
const mxnet::TShape& ishape = in_data[fullc::kData].shape_;
const mxnet::TShape& oshape = out_data[fullc::kOut].shape_;

Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> data, out;
if (!param.flatten) {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape.ProdShape(0, ishape.ndim()-1), ishape[ishape.ndim()-1]), s);
out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
data = FlattenAs2DHead<xpu, DType>(in_data[fullc::kData], ctx);
out = FlattenAs2DHead<xpu, DType>(out_data[fullc::kOut], ctx);
} else {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
data = FlattenAs2DTail<xpu, DType>(in_data[fullc::kData], ctx);
out = FlattenAs2DTail<xpu, DType>(out_data[fullc::kOut], ctx);
}

CHECK_EQ(data.shape_[1], wmat.shape_[1])
@@ -339,47 +369,38 @@ void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
using namespace mshadow::expr;
// TODO(bing): check the BLAS Handle, be careful
// maybe need blas handle from context
Stream<xpu> *s = ctx.get_stream<xpu>();
const mxnet::TShape& ishape = in_data[fullc::kData].shape_;
const mxnet::TShape& oshape = out_grad[fullc::kOut].shape_;

Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> data, grad, gdata;
Stream<xpu> *stream = ctx.get_stream<xpu>();
Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(stream);
Tensor<xpu, 2, DType> x, y_grad, x_grad;
if (!param.flatten) {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape.ProdShape(0, ishape.ndim()-1), ishape[ishape.ndim()-1]), s);
grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape.ProdShape(0, ishape.ndim()-1), ishape[ishape.ndim()-1]), s);
x = FlattenAs2DHead<xpu, DType>(in_data[fullc::kData], ctx);
y_grad = FlattenAs2DHead<xpu, DType>(out_grad[fullc::kOut], ctx);
x_grad = FlattenAs2DHead<xpu, DType>(in_grad[fullc::kData], ctx);
} else {
data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
x = FlattenAs2DTail<xpu, DType>(in_data[fullc::kData], ctx);
y_grad = FlattenAs2DTail<xpu, DType>(out_grad[fullc::kOut], ctx);
x_grad = FlattenAs2DTail<xpu, DType>(in_grad[fullc::kData], ctx);
}

#if defined(__CUDACC__)
CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
CHECK_EQ(stream->blas_handle_ownership_, Stream<xpu>::OwnHandle)
<< "Must init CuBLAS handle in stream";
#endif
// backprop
CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
// gradient of weight
Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> w_grad = in_grad[fullc::kWeight].get<xpu, 2, DType>(stream);
// Legacy approach shown here for comparison:
// out = Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));
linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
// out = Assign(w_grad, req[fullc::kWeight], dot(y_grad.T(), x));
linalg_gemm(y_grad, x, w_grad, true, false, stream, req[fullc::kWeight]);
// gradient of bias
if (!param.no_bias) {
AddBiasGrad(in_grad[fullc::kBias], grad, req[fullc::kBias], param.num_hidden, ctx);
AddBiasGrad(in_grad[fullc::kBias], y_grad, req[fullc::kBias], param.num_hidden, ctx);
}
// gradient of data
// Legacy approach shown here for comparison:
// Assign(gdata, req[fullc::kData], dot(grad, wmat));
linalg_gemm(grad, wmat, gdata, false, false, s, req[fullc::kData]);
// Assign(x_grad, req[fullc::kData], dot(y_grad, wmat));
linalg_gemm(y_grad, wmat, x_grad, false, false, stream, req[fullc::kData]);
}
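
In matrix form, the three assignments above compute the standard FullyConnected gradients. A summary sketch, assuming the usual convention y = x W^T + b with x of shape (n, d), W of shape (h, d), and y_grad of shape (n, h):

\[
w\_grad = y\_grad^{\top} x \in \mathbb{R}^{h \times d}, \qquad
x\_grad = y\_grad\, W \in \mathbb{R}^{n \times d}, \qquad
b\_grad = \sum_{i=1}^{n} y\_grad_{i,:} \in \mathbb{R}^{h}.
\]
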

template<typename xpu>
@@ -418,7 +439,7 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(inputs.size(), 3U); // ograd_y, x, w
CHECK_EQ(outputs.size(), out_expected);
CHECK_EQ(req.size(), out_expected);

@@ -442,6 +463,114 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
}
}



///
// Inputs are:
// o_x_grad : head gradient for x_grad
// o_w_grad : head gradient for w_grad
// o_b_grad : if param.no_bias is false
// o_y : head gradient of y
//
// Outputs are:
// o_y_grad : gradient of o_y
// x_grad_grad : o_y * o_w_grad
// w_grad_grad : o_y.T * o_x_grad
// b_grad_grad : if param.no_bias is false
//
// For implementation details see this PR: https://github.com/apache/incubator-mxnet/pull/14779
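
These two expressions follow from a short calculation; a derivation sketch, assuming the same convention y = x W^T + b (x of shape (n, d), W of shape (h, d), o_y = dL/dy of shape (n, h)). The first backward pass computes

\[
x\_grad = o\_y\, W, \qquad w\_grad = o\_y^{\top} x .
\]

Both are linear in $(x, W)$ for fixed $o\_y$. Given head gradients $o\_x\_grad$ and $o\_w\_grad$ of the same shapes as $x\_grad$ and $w\_grad$, differentiating

\[
\langle o\_x\_grad,\; o\_y\, W \rangle + \langle o\_w\_grad,\; o\_y^{\top} x \rangle
\]

with respect to $x$ and $W$ gives

\[
x\_grad\_grad = o\_y\; o\_w\_grad, \qquad
w\_grad\_grad = o\_y^{\top} o\_x\_grad,
\]

which match the two linalg_gemm calls in FullyConnectedGradGradCompute below.
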

/**
* Second order gradient for Fully Connected
* x_grad_grad = o_y * o_w_grad
* w_grad_grad = o_y.T * o_x_grad
*
* @tparam xpu
* @tparam DType
* @param attrs
* @param ctx
* @param inputs
* @param req
* @param outputs
*/
template<typename xpu, typename DType>
void FullyConnectedGradGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace std;
using namespace fullc;
Stream<xpu> *stream = ctx.get_stream<xpu>();
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
const size_t num_inputs = param.no_bias ? 3U : 4U;
// inputs are: o_x_grad, o_w_grad, o_y (no bias) or o_x_grad, o_w_grad, o_b_grad, o_y (with bias)
// outputs are: o_y_grad, x_grad_grad, w_grad_grad
const size_t num_outputs = 3U;
CHECK_EQ(inputs.size(), num_inputs);
CHECK_EQ(outputs.size(), num_outputs);
CHECK_EQ(req.size(), num_outputs);

// inputs
Tensor<xpu, 2, DType> o_x_grad;
Tensor<xpu, 2, DType> o_w_grad;
Tensor<xpu, 2, DType> o_y;
// unused
// Tensor<xpu, 1, DType> o_b_grad;

// outputs
Tensor<xpu, 2, DType> o_y_grad;
TBlob o_y_grad_blob = outputs[kOyGrad];
Tensor<xpu, 2, DType> x_grad_grad;
Tensor<xpu, 2, DType> w_grad_grad;
Tensor<xpu, 1, DType> b_grad_grad;
size_t o_y_idx = std::numeric_limits<size_t>::max();
if (param.no_bias)
o_y_idx = kOy;
else
o_y_idx = kOyBias;
if (!param.flatten) {
o_x_grad = FlattenAs2DHead<xpu, DType>(inputs[kOxGrad], ctx);
o_w_grad = inputs[kOwGrad].get<xpu, 2, DType>(stream);
o_y = FlattenAs2DHead<xpu, DType>(inputs[o_y_idx], ctx);
x_grad_grad = FlattenAs2DHead<xpu, DType>(outputs[kXGradGrad], ctx);
w_grad_grad = FlattenAs2DHead<xpu, DType>(outputs[kWGradGrad], ctx);
} else {
o_x_grad = FlattenAs2DTail<xpu, DType>(inputs[kOxGrad], ctx);
o_w_grad = FlattenAs2DTail<xpu, DType>(inputs[kOwGrad], ctx);
o_y = inputs[o_y_idx].get<xpu, 2, DType>(stream);
x_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[kXGradGrad], ctx);
w_grad_grad = FlattenAs2DTail<xpu, DType>(outputs[kWGradGrad], ctx);
}
linalg_gemm(o_y, o_w_grad, x_grad_grad, false, false, stream, req[kXGradGrad]);
linalg_gemm(o_y, o_x_grad, w_grad_grad, true, false, stream, req[kWGradGrad]);
// 3rd order not supported
Fill(stream, o_y_grad_blob, kWriteTo, static_cast<DType>(0));
/* TODO(larroy) bias is not supported yet as there's no bias input to backward. Bias grad grad is
Review comment (inline): @sxjscience Could you please review if this is correct?
* zero.
if (!param.no_bias) {
// The second order gradient for b doesn't depend on x or w. Thus we set it to 0.
b_grad_grad = outputs.at(kBGradGrad).get<xpu, 1, DType>(stream);
TBlob b_grad_grad_blob = TBlob(b_grad_grad);
Fill(stream, b_grad_grad_blob, kWriteTo, static_cast<DType>(0));
}
*/
}


template<typename xpu>
void FullyConnectedGradGradDTypeDispatch(
const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const int dtype = inputs[0].type_flag_;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
FullyConnectedGradGradCompute<xpu, DType>(attrs, ctx, inputs, req, outputs);
});
}
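
MSHADOW_REAL_TYPE_SWITCH expands, roughly, to a switch over the runtime type flag that instantiates the template for the matching C++ type. A simplified stand-alone sketch of that dispatch pattern (hypothetical flag values; the real macro also covers float16 and errors out on unknown flags):

#include <cstdio>

enum TypeFlag { kFloat32 = 0, kFloat64 = 1 };  // hypothetical subset of dtype flags

template <typename DType>
void ComputeImpl() {
  std::printf("instantiated for a %zu-byte floating point type\n", sizeof(DType));
}

// Dispatch from a runtime dtype flag to a compile-time template parameter.
void Dispatch(int dtype) {
  switch (dtype) {
    case kFloat32: ComputeImpl<float>(); break;
    case kFloat64: ComputeImpl<double>(); break;
    default: std::printf("unsupported dtype flag %d\n", dtype); break;
  }
}
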


} // namespace op
} // namespace mxnet
namespace std {
Expand Down
56 changes: 46 additions & 10 deletions src/operator/nn/fully_connected.cc
@@ -176,11 +176,21 @@ struct FullyConnectedGrad {
}
};

inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
struct FullyConnectedGradGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
heads.push_back(n->inputs[0]); // o_y : head gradient of the output y
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};

static bool FCStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
const bool valid_data = in_attrs->at(0) == kDefaultStorage;
const bool valid_weight = in_attrs->at(1) == kDefaultStorage ||
@@ -210,11 +220,11 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}

inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
@@ -324,6 +334,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedGradGrad{"_backward_backward_FullyConnected"})
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
#if MXNET_USE_MKLDNN == 1
@@ -332,5 +343,30 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
#endif
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradCompute<cpu>);

// 2nd gradient for fully connected
// Inputs are:
// o_x_grad : head gradient for x_grad
// o_w_grad : head gradient for w_grad
// o_b_grad : if param.no_bias is false
// o_y : head gradient of y
//
// Outputs are:
// o_y_grad : not used (zero-filled; third order is not supported)
// x_grad_grad : o_y * o_w_grad
// w_grad_grad : o_y^T * o_x_grad
//
// For a detailed development of the second gradient see here: TODO(larroy)
NNVM_REGISTER_OP(_backward_backward_FullyConnected)
.set_num_inputs([](const NodeAttrs& attrs) {
const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
return params.no_bias ? 3 : 4;
})
.set_num_outputs([](const NodeAttrs& attrs) {
return 3;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<FullyConnectedParam>)
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradGradDTypeDispatch<cpu>);

} // namespace op
} // namespace mxnet