
[Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 #18504

Merged: 21 commits, Jul 9, 2020
src/operator/nn/batch_norm.cc (12 changes: 8 additions & 4 deletions)
@@ -420,10 +420,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,

#if MXNET_USE_MKLDNN == 1
static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
Contributor:
We are removing the check for ndim == 4 here, as well as the lighter check for ndim == 1 || ndim == 2 || ndim == 4 in SupportMKLDNN.
Does that mean ndim can now be anything > 0? What are the allowed values for ndim?

@wkcn (Member, Author), Jul 3, 2020:
Yes, ndim can be anything > 0.

If the input has shape `shape`, it is reshaped to `(prod(shape[0:axis]), shape[axis], 1, prod(shape[axis+1:len(shape)]))`. The table below lists the resulting 4-D shape for each ndim and axis (x marks an invalid axis); a short sketch of the rule follows the table.

| ndim | shape | axis=0 | axis=1 | axis=2 | axis=3 | axis=4 |
|------|-------|--------|--------|--------|--------|--------|
| 1 | (N,) | (1, N, 1, 1) | x | x | x | x |
| 2 | (N, C) | (1, N, 1, C) | (N, C, 1, 1) | x | x | x |
| 3 | (N, C, H) | (1, N, 1, CH) | (N, C, 1, H) | (NC, H, 1, 1) | x | x |
| 4 | (N, C, H, W) | (1, N, 1, CHW) | (N, C, 1, HW) | (NC, H, 1, W) | (NCH, W, 1, 1) | x |
| 5 | (N, D, C, H, W) | (1, N, 1, DCHW) | (N, D, 1, CHW) | (ND, C, 1, HW) | (NDC, H, 1, W) | (NDCH, W, 1, 1) |
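
To make the rule concrete, here is a small standalone sketch (illustrative only, not code from this PR; `collapse_to_4d` is a made-up name) that computes the 4-D shape described above:

```python
from functools import reduce
from operator import mul


def collapse_to_4d(shape, axis):
    # (prod(shape[0:axis]), shape[axis], 1, prod(shape[axis+1:]))
    prod = lambda dims: reduce(mul, dims, 1)
    return (prod(shape[:axis]), shape[axis], 1, prod(shape[axis + 1:]))


print(collapse_to_4d((2, 3, 4, 5), axis=2))     # (6, 4, 1, 5)   i.e. (NC, H, 1, W)
print(collapse_to_4d((2, 3, 4, 5, 6), axis=1))  # (2, 3, 1, 120) i.e. (N, D, 1, CHW)
```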

Contributor:
Thanks for the explanation! Does the table continue in the same way for ndim > 5, or should we add a check for that?

@wkcn (Member, Author):
Yes, it supports ndim > 5 too.
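
As a quick illustration (again not code from the PR), the same rule applied to a 6-D shape:

```python
from functools import reduce
from operator import mul

shape, axis = (2, 3, 4, 5, 6, 7), 3
prod = lambda dims: reduce(mul, dims, 1)
print((prod(shape[:axis]), shape[axis], 1, prod(shape[axis + 1:])))
# (24, 5, 1, 42): any ndim > 0 collapses to a valid (N, C, 1, D) 4-D shape
```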

-         && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-         && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+         SupportStorageMKLDNN(input.storage_type());
}

void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
src/operator/nn/batch_norm.cu (6 changes: 2 additions & 4 deletions)
@@ -698,8 +698,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
})
@@ -727,8 +726,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
})
src/operator/nn/cudnn/cudnn_batch_norm-inl.h (26 changes: 19 additions & 7 deletions)
@@ -262,15 +262,27 @@ class CuDNNBatchNormOp {

private:
void Init(const TBlob &in_data) {
-   if (in_data.ndim() == 4) {
-     for (int i = 0; i < 4; ++i)
-       shape_[i] = in_data.shape_[i];
+   CHECK_GE(param_.axis, 0);
+   CHECK_LT(param_.axis, in_data.ndim());
+   if (param_.axis == 1) {
+     if (in_data.ndim() == 4) {
+       for (int i = 0; i < 4; ++i)
+         shape_[i] = in_data.shape_[i];
+     } else {
+       // when in_data.ndim() != 4
+       shape_[0] = in_data.shape_[0];
+       shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+       shape_[2] = 1;
+       shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+                                                               in_data.ndim()));
+     }
    } else {
-     // when in_data.ndim() != 4
-     shape_[0] = in_data.shape_[0];
-     shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+     // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+     shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+     shape_[1] = in_data.shape_[param_.axis];
      shape_[2] = 1;
-     shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+     shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                                                             in_data.ndim()));
    }

CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h (44 changes: 39 additions & 5 deletions)
@@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs, bool fuse_relu) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }

const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
fuse_relu);
const NDArray &data = in_data[batchnorm::kData];
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];

// for output memory
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
ctx.is_train && !param.use_global_stats,
fuse_relu);

-  const NDArray &data = in_data[batchnorm::kData];
-  const NDArray &diff = out_grad[batchnorm::kOut];
-  const NDArray &gradIn = in_grad[batchnorm::kData];
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
+  NDArray gradIn = in_grad[batchnorm::kData];
const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
CHECK(moving_mean.IsDefaultData());
CHECK(moving_var.IsDefaultData());

+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }

auto data_mem = data.GetMKLDNNData();
auto diff_mem = diff.GetMKLDNNData();
// MKLDNN batchnorm should run on special layouts. If one of them isn't, we
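
Not part of the diff above: a minimal consistency check one could run against this branch (it assumes an MXNet build that includes this PR with MKLDNN and/or cuDNN enabled; the shapes, tolerances, and the use of the imperative `mx.nd.BatchNorm` API are illustrative only, not the PR's test suite). It compares BatchNorm over a non-default axis with the default-axis path on transposed data:

```python
import mxnet as mx
import numpy as np

x = mx.nd.random.uniform(shape=(2, 3, 4, 5))
# Per-channel parameters and aux states for C = 5 channels (the last axis).
gamma, beta = mx.nd.ones((5,)), mx.nd.zeros((5,))
mean, var = mx.nd.zeros((5,)), mx.nd.ones((5,))

with mx.autograd.record(train_mode=True):
    # Path under test: channel axis is the last dimension.
    y_axis3 = mx.nd.BatchNorm(x, gamma, beta, mean.copy(), var.copy(),
                              axis=3, eps=1e-5)
    # Reference path: move channels to axis 1 and use the default axis.
    y_axis1 = mx.nd.BatchNorm(mx.nd.transpose(x, axes=(0, 3, 1, 2)),
                              gamma, beta, mean.copy(), var.copy(),
                              axis=1, eps=1e-5)

np.testing.assert_allclose(y_axis3.asnumpy(),
                           mx.nd.transpose(y_axis1, axes=(0, 2, 3, 1)).asnumpy(),
                           rtol=1e-3, atol=1e-3)
```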