[Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (apache#18504)

* fix batch norm when fix_gamma is True

* support gradient accumulation for batch norm

* mkldnn batchnorm support grad add

* unittest for bn

* fix bn arg

* fix lint

* fix mkldnn

* fix mkldnn bn

* fix grad when fixing gamma

* fix naive gpu bn

* fix lint

* invoke mkldnn and cudnn batchnorm when axis != 1

* backport 18500

* change condition

* fix

* fix

* add mkldnn_off for bn

* remove mkldnn_off

* recover save_000800.json

* cast
wkcn authored and chinakook committed Nov 17, 2020
1 parent bb9096f commit b8a1fbd
Showing 4 changed files with 68 additions and 20 deletions.
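At the core of the patch, an input normalized over an arbitrary axis is viewed as a 4D (N, C, 1, D) tensor whose channel axis is 1, where N is the product of the dimensions before the normalized axis and D the product of the dimensions after it. The minimal standalone sketch below illustrates that collapse; `Shape` and `CollapseToNC1D` are hypothetical stand-ins for illustration, not MXNet API.

    // Collapse an arbitrary shape/axis pair into the 4D (N, C, 1, D) layout
    // expected by channel-axis-1 batch-norm kernels. Illustrative sketch only.
    #include <array>
    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    using Shape = std::vector<int64_t>;

    std::array<int64_t, 4> CollapseToNC1D(const Shape& shape, int axis) {
      assert(axis >= 0 && axis < static_cast<int>(shape.size()));
      // N: product of all dimensions before the normalized axis.
      const int64_t n = std::accumulate(shape.begin(), shape.begin() + axis,
                                        int64_t{1}, std::multiplies<int64_t>());
      // D: product of all dimensions after the normalized axis.
      const int64_t d = std::accumulate(shape.begin() + axis + 1, shape.end(),
                                        int64_t{1}, std::multiplies<int64_t>());
      return {n, shape[axis], 1, d};
    }

    int main() {
      // NHWC input (2, 4, 4, 3) normalized over axis 3 becomes (32, 3, 1, 1):
      // each of the 3 channels keeps statistics over 2*4*4 = 32 elements.
      auto s1 = CollapseToNC1D({2, 4, 4, 3}, 3);
      assert((s1 == std::array<int64_t, 4>{32, 3, 1, 1}));

      // A 3D (batch, time, feature) input normalized over axis 2 becomes (N*T, C, 1, 1).
      auto s2 = CollapseToNC1D({8, 10, 16}, 2);
      assert((s2 == std::array<int64_t, 4>{80, 16, 1, 1}));

      // The usual NCHW case with axis 1: (2, 3, 4, 5) -> (2, 3, 1, 20).
      auto s3 = CollapseToNC1D({2, 3, 4, 5}, 1);
      assert((s3 == std::array<int64_t, 4>{2, 3, 1, 20}));
      return 0;
    }

Because the collapse only groups adjacent dimensions, it amounts to a metadata-level reshape of a row-major buffer; no transpose or copy is needed.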
12 changes: 8 additions & 4 deletions src/operator/nn/batch_norm.cc
@@ -435,10 +435,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+         SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
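In the hunk above, SupportMKLDNNBN no longer requires a 4D input with the channel on axis 1; the axis is handled later by the (N, C, 1, D) reshape in mkldnn_batch_norm-inl.h. Below is a self-contained restatement of the new gate with stub types (`DType`, `Storage` and `SupportsMkldnnBatchNorm` are illustrative names, not MXNet identifiers).

    // Relaxed gate: only the global disable flag, empty inputs, unsupported
    // dtypes and non-default storage are rejected; rank and axis no longer matter.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    enum class DType { kFloat32, kBfloat16, kFloat64 };
    enum class Storage { kDefault, kSparse };

    bool SupportsMkldnnBatchNorm(const std::vector<int64_t>& shape, DType dtype,
                                 Storage storage, bool disable_mkl) {
      if (disable_mkl) return false;
      if (shape.empty()) return false;
      int64_t size = 1;
      for (int64_t d : shape) size *= d;   // a zero dimension means an empty tensor
      if (size == 0) return false;
      return (dtype == DType::kFloat32 || dtype == DType::kBfloat16) &&
             storage == Storage::kDefault;
    }

    int main() {
      // Previously only 4D inputs with axis == 1 qualified; rank is irrelevant now.
      assert(SupportsMkldnnBatchNorm({8, 16}, DType::kFloat32, Storage::kDefault, false));
      assert(SupportsMkldnnBatchNorm({2, 4, 4, 3}, DType::kBfloat16, Storage::kDefault, false));
      // Empty tensors, float64 and sparse storage still fall back.
      assert(!SupportsMkldnnBatchNorm({0, 3}, DType::kFloat32, Storage::kDefault, false));
      assert(!SupportsMkldnnBatchNorm({8, 16}, DType::kFloat64, Storage::kDefault, false));
      assert(!SupportsMkldnnBatchNorm({8, 16}, DType::kFloat32, Storage::kSparse, false));
      return 0;
    }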
6 changes: 2 additions & 4 deletions src/operator/nn/batch_norm.cu
@@ -698,8 +698,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
@@ -727,8 +726,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
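Both GPU entry points above normalize the axis via GetRealAxis before dispatching, so the cuDNN operator only ever sees a non-negative axis and its CHECK_GE(param_.axis, 0) holds. The sketch below shows the usual wrap-around convention assumed here; `RealAxis` is an illustrative name, and the exact checks inside MXNet's helper may differ.

    // Negative axes count from the end: axis -1 on an N-dimensional input maps
    // to N - 1. Standalone sketch of the assumed convention, not MXNet code.
    #include <cassert>
    #include <stdexcept>

    int RealAxis(int axis, int ndim) {
      if (axis < 0) axis += ndim;              // e.g. axis = -1, ndim = 4 -> 3
      if (axis < 0 || axis >= ndim) throw std::out_of_range("axis out of range");
      return axis;
    }

    int main() {
      assert(RealAxis(-1, 4) == 3);   // channels-last NHWC
      assert(RealAxis(1, 4) == 1);    // default NCHW channel axis
      assert(RealAxis(-3, 3) == 0);
      return 0;
    }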
26 changes: 19 additions & 7 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -262,15 +262,27 @@ class CuDNNBatchNormOp {
 
  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+                                                                in_data.ndim()));
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                                                              in_data.ndim()));
     }
 
     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
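The descriptor collapse above is exact: the (N, C, 1, D) view of a row-major buffer keeps every element's linear offset, so per-channel statistics computed over the view equal per-axis statistics computed over the original layout. The following self-contained numerical check (not MXNet code) verifies this for a 3D input normalized over its last axis.

    // Compare the per-channel mean over the original (2, 3, 4) layout, axis = 2,
    // with the per-channel mean over the same buffer viewed as (6, 4, 1, 1).
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t d0 = 2, d1 = 3, d2 = 4;   // original shape (2, 3, 4), axis = 2
      std::vector<double> buf(d0 * d1 * d2);
      for (size_t i = 0; i < buf.size(); ++i) buf[i] = 0.5 * i * i - 3.0 * i;

      // Per-channel mean in the original layout: average over (i, j) for fixed k.
      std::vector<double> mean_orig(d2, 0.0);
      for (int64_t i = 0; i < d0; ++i)
        for (int64_t j = 0; j < d1; ++j)
          for (int64_t k = 0; k < d2; ++k)
            mean_orig[k] += buf[(i * d1 + j) * d2 + k] / (d0 * d1);

      // Same buffer viewed as (N, C, 1, D) = (d0 * d1, d2, 1, 1), channel axis 1.
      const int64_t N = d0 * d1, C = d2, D = 1;
      std::vector<double> mean_view(C, 0.0);
      for (int64_t n = 0; n < N; ++n)
        for (int64_t c = 0; c < C; ++c)
          for (int64_t d = 0; d < D; ++d)
            mean_view[c] += buf[((n * C + c) * 1 + 0) * D + d] / (N * D);

      for (int64_t c = 0; c < C; ++c)
        assert(std::fabs(mean_orig[c] - mean_view[c]) < 1e-9);
      return 0;
    }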
44 changes: 39 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
                             const std::vector<NDArray> &outputs, bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }
+
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 fuse_relu);
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 ctx.is_train && !param.use_global_stats,
                                                 fuse_relu);
 
-  const NDArray &data = in_data[batchnorm::kData];
-  const NDArray &diff = out_grad[batchnorm::kOut];
-  const NDArray &gradIn = in_grad[batchnorm::kData];
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
+  NDArray gradIn = in_grad[batchnorm::kData];
   const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
   const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
   const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   CHECK(moving_mean.IsDefaultData());
   CHECK(moving_var.IsDefaultData());
 
+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }
+
   auto data_mem = data.GetMKLDNNData();
   auto diff_mem = diff.GetMKLDNNData();
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
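In the backward path only data, diff and gradIn are reshaped; gamma, beta and the moving statistics stay untouched because the collapse preserves the channel dimension, so a per-channel reduction over the (N, C, 1, D) view matches the per-axis reduction over the original layout. The backward counterpart of the earlier check, as a self-contained sketch (not MXNet code), uses dbeta, the per-channel sum of the output gradient.

    // Compare per-channel gradient sums over an NHWC-style (2, 4, 4, 3) layout
    // (axis = 3) with sums over the same buffer viewed as (32, 3, 1, 1).
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t n0 = 2, h = 4, w = 4, c = 3;       // original shape, axis = 3
      std::vector<double> grad(n0 * h * w * c);
      for (size_t i = 0; i < grad.size(); ++i) grad[i] = std::sin(0.37 * i);

      // Per-channel sum in the original NHWC layout.
      std::vector<double> dbeta_orig(c, 0.0);
      for (int64_t a = 0; a < n0; ++a)
        for (int64_t b = 0; b < h; ++b)
          for (int64_t d = 0; d < w; ++d)
            for (int64_t e = 0; e < c; ++e)
              dbeta_orig[e] += grad[((a * h + b) * w + d) * c + e];

      // Same buffer viewed as (N, C, 1, D) = (n0 * h * w, c, 1, 1).
      const int64_t N = n0 * h * w, C = c, D = 1;
      std::vector<double> dbeta_view(C, 0.0);
      for (int64_t n = 0; n < N; ++n)
        for (int64_t ch = 0; ch < C; ++ch)
          for (int64_t d = 0; d < D; ++d)
            dbeta_view[ch] += grad[((n * C + ch) * 1) * D + d];

      for (int64_t ch = 0; ch < C; ++ch)
        assert(std::fabs(dbeta_orig[ch] - dbeta_view[ch]) < 1e-9);
      return 0;
    }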
