Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 #18504

Merged
merged 21 commits into from
Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/operator/nn/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
bool output_mean_var;
int axis;
bool cudnn_off;
bool mkldnn_off;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
Expand All @@ -96,6 +97,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
.describe("Specify which shape axis the channel is specified");
DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
.describe("Do not select CUDNN operator, if available");
DMLC_DECLARE_FIELD(mkldnn_off).set_default(false)
.describe("Do not select MKLDNN operator, if available");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe("The minimum scalar value in the form of float32 obtained "
Expand All @@ -116,6 +119,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
this->use_global_stats == other.use_global_stats &&
this->output_mean_var == other.output_mean_var && this->axis == other.axis &&
this->cudnn_off == other.cudnn_off &&
this->mkldnn_off == other.mkldnn_off &&
this->min_calib_range.has_value() == other.min_calib_range.has_value() &&
this->max_calib_range.has_value() == other.max_calib_range.has_value();
if (this->min_calib_range.has_value() && other.min_calib_range.has_value() &&
Expand Down
12 changes: 8 additions & 4 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,

#if MXNET_USE_MKLDNN == 1
static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
mxnet::TShape shape = input.shape();
return SupportMKLDNN(input) && shape.ndim() == 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are removing the check for ndim == 4 here, and another lighter check for ndim == 1 || ndim ==2 || ndim ==4 present in SupportMKLDNN.
Does that mean ndim can be anything >0 ? What are the allowed values for ndim?

Copy link
Member Author

@wkcn wkcn Jul 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ndim coulld be anything > 0.

If the shape of input A is shape, the input will be shaped into (prod(shape[0:axis]), shape[axis], 1, prod(shape[axis+1:len(shape)]) ).

ndim shape axis=0 axis=1 axis=2 axis=3 axis=4
1 (N,) (1, N, 1, 1) x x x x
2 (N,C) (1, N, 1, C) (N, C, 1, 1) x x x
3 (N,C,H) (1, N, 1, CH) (N, C, 1, H) (NC, H, 1, 1) x x
4 (N,C,H,W) (1, N, 1, CHW) (N, C, 1, HW) (NC, H, 1, W) (NCH, W, 1, 1) x
5 (N,D,C,H,W) (1, N, 1, DCHW) (N, D, 1, CHW) (ND,C, 1, HW) (NDC, H, 1, W) (NDCH, W, 1, 1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! Does the table continue further for ndims > 5? Or should we place a check for that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it supports ndim > 5 too.

&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
&& !mxnet::op::batchnorm::disable_mkl;
if (mxnet::op::batchnorm::disable_mkl || param.mkldnn_off) return false;
const mxnet::TShape shape = input.shape();
const int ndim = shape.ndim();
if (ndim == 0 || shape.Size() == 0) return false;
const int dtype = input.dtype();
return (dtype == mshadow::kFloat32 ||
dtype == mshadow::kBfloat16) &&
SupportStorageMKLDNN(input.storage_type());
}

void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
Expand Down
6 changes: 2 additions & 4 deletions src/operator/nn/batch_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
})
Expand Down Expand Up @@ -727,8 +726,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
})
Expand Down
25 changes: 18 additions & 7 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,26 @@ class CuDNNBatchNormOp {

private:
void Init(const TBlob &in_data) {
if (in_data.ndim() == 4) {
for (int i = 0; i < 4; ++i)
shape_[i] = in_data.shape_[i];
CHECK_GE(param_.axis, 0);
CHECK_LT(param_.axis, in_data.ndim());
if (param_.axis == 1) {
if (in_data.ndim() == 4) {
for (int i = 0; i < 4; ++i)
shape_[i] = in_data.shape_[i];
} else {
// when in_data.ndim() != 4
shape_[0] = in_data.shape_[0];
shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
shape_[2] = 1;
shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
}
} else {
// when in_data.ndim() != 4
shape_[0] = in_data.shape_[0];
shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
// reshape to (N, C, 1, D), C is the `param_.axis` dimension
shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
shape_[1] = in_data.shape_[param_.axis];
shape_[2] = 1;
shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
static_cast<int>(in_data.ndim())));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need static_cast here? why cant we do it like line 276?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return dtype of ProdShape is size_t, a unsigned int, but the dtype of shape_[i] is dim_t, namely int64_t.
I do not know whether the static_cast is neceseary to avoid the potential compiler's warning.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add static_cast to line 276 as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in_data.ndims() shouldn’t need a static cast though, right?

Copy link
Member Author

@wkcn wkcn Jul 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the dtype of .ndim() is int32_t.

variable dtype
axis int32_t
ProdShape size_t
shape_[i] dim_t, int64_t
ndim() int32_t

The signature of ProdShape is size_t ProdShape(int dimstart, int dimend) const.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mseth10 , do you have any suggestion about whether to use static_cast ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo you can ignore the compiler warnings (if any) for int32_t to int,
you can keep static_cast for size_t to dim_t

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I have updated it : )

}

CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
Expand Down
44 changes: 39 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs, bool fuse_relu) {
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);

mxnet::TShape shape = inputs[batchnorm::kData].shape();
const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
CHECK_LT(real_axis, shape.ndim());
NDArray out = outputs[batchnorm::kOut];
if (param.axis != 1 || shape.ndim() != 4) {
// reshape to (N, C, 1, D)
mxnet::TShape new_shape{
static_cast<dim_t>(shape.ProdShape(0, real_axis)),
shape[real_axis],
1,
static_cast<dim_t>(shape.ProdShape(real_axis + 1,
static_cast<int>(shape.ndim())))
};
in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
out = out.Reshape(new_shape);
}

const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
mkldnn::normalization_flags flags = _GetFlags(in_data,
Expand All @@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
fuse_relu);
const NDArray &data = in_data[batchnorm::kData];
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
const NDArray &out = outputs[batchnorm::kOut];

// for output memory
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
Expand Down Expand Up @@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
ctx.is_train && !param.use_global_stats,
fuse_relu);

const NDArray &data = in_data[batchnorm::kData];
const NDArray &diff = out_grad[batchnorm::kOut];
const NDArray &gradIn = in_grad[batchnorm::kData];
NDArray data = in_data[batchnorm::kData];
NDArray diff = out_grad[batchnorm::kOut];
NDArray gradIn = in_grad[batchnorm::kData];
const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
const NDArray &out_mean = out_data[batchnorm::kMean];
Expand All @@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
CHECK(moving_mean.IsDefaultData());
CHECK(moving_var.IsDefaultData());

mxnet::TShape shape = data.shape();
const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
CHECK_LT(real_axis, shape.ndim());
if (param.axis != 1 || shape.ndim() != 4) {
// reshape to (N, C, 1, D)
mxnet::TShape new_shape{
static_cast<dim_t>(shape.ProdShape(0, real_axis)),
shape[real_axis],
1,
static_cast<dim_t>(shape.ProdShape(real_axis + 1,
static_cast<int>(shape.ndim())))
};
data = data.Reshape(new_shape);
diff = diff.Reshape(new_shape);
gradIn = gradIn.Reshape(new_shape);
}

auto data_mem = data.GetMKLDNNData();
auto diff_mem = diff.GetMKLDNNData();
// MKLDNN batchnorm should run on special layouts. If one of them isn't, we
Expand Down
5 changes: 3 additions & 2 deletions tests/python/unittest/save_000800.json
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@
"eps": "0.001",
"fix_gamma": "True",
"momentum": "0.9",
"use_global_stats": "False"
"use_global_stats": "False",
"mkldnn_off": "True"
},
"name": "batchnorm0",
"inputs": [[11, 0], [12, 0], [13, 0]],
Expand Down Expand Up @@ -185,4 +186,4 @@
],
"arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12, 13, 15],
"heads": [[16, 0]]
}
}
2 changes: 1 addition & 1 deletion tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def test_load_000800():
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, lr_mult=0.01)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0')
fc3 = mx.symbol.BatchNorm(fc3, mkldnn_off=True, name='batchnorm0')
sym1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
Expand Down