This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

update support MKLDNN BN conditions #15870

Merged
11 commits merged on Aug 23, 2019
10 changes: 4 additions & 6 deletions src/operator/nn/batch_norm.cc
@@ -384,7 +384,6 @@ static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param)
mxnet::TShape shape = input.shape();
return SupportMKLDNN(input) && shape.ndim() == 4
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
- && shape[param.axis] % 8 == 0
&& !mxnet::op::batchnorm::disable_mkl;
}
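
For readability, this is how the support helper reads once the hunk above is applied (the lines are simply recombined from the diff; treat it as a reading aid, not new code): the MKLDNN batch-norm path now only requires a 4-D input on the default channel axis with MKL-DNN enabled, and no longer requires the channel count to be a multiple of 8.

static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
  // Relaxed check: the former `shape[param.axis] % 8 == 0` requirement is gone.
  mxnet::TShape shape = input.shape();
  return SupportMKLDNN(input) && shape.ndim() == 4
         && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
         && !mxnet::op::batchnorm::disable_mkl;
}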

@@ -395,8 +394,8 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
- // MKLDNN batchnorm only works well on the special MKLDNN layout.
- if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+
+ if (SupportMKLDNNBN(inputs[0], param)) {
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
Contributor:
Please remove the comments in L397.

std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());

@@ -419,9 +418,8 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);

mxnet::TShape shape = inputs[0].shape();
- // MKLDNN batchnorm only works well on the special MKLDNN layout.
- if (SupportMKLDNNBN(inputs[0], param)
-     && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
+
+ if (SupportMKLDNNBN(inputs[0], param)) {
Contributor:
Why don't we need to check these two conditions?

Contributor (Author):
It's the same reason as line 399: as described in the Performance part, when a model has fp32 BN along with other int8 ops, it fails the IsMKLDNNData() check and cannot run in MKLDNN, which causes slower speed.

std::vector<NDArray> out_grad(1);
std::vector<NDArray> out_data(3);
std::vector<NDArray> in_data(3);
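
To make the scenario from the author's comment above concrete, here is a small standalone sketch (MXNet types mocked; all names are hypothetical, this is not code from the PR): under the old guard a default-layout fp32 tensor, e.g. one produced between int8 operators in a mixed-precision graph, fails the IsMKLDNNData() test and batch norm falls back to the slower generic CPU kernel even though SupportMKLDNNBN() holds, while the new guard dispatches on operator support alone.

#include <cstdio>

// Hypothetical stand-in for NDArray: only tracks whether it carries an MKLDNN layout.
struct FakeTensor { bool mkldnn_layout; };

// Old backward guard: also required an MKLDNN layout on the gradient or the data.
static bool old_guard(bool supported, const FakeTensor &grad, const FakeTensor &data) {
  return supported && (grad.mkldnn_layout || data.mkldnn_layout);
}

// New backward guard: the layout no longer matters, only SupportMKLDNNBN().
static bool new_guard(bool supported, const FakeTensor &, const FakeTensor &) {
  return supported;
}

int main() {
  FakeTensor fp32_default{false};  // plain fp32 tensor from a mixed int8/fp32 graph
  std::printf("old guard: %d (falls back), new guard: %d (MKLDNN path)\n",
              old_guard(true, fp32_default, fp32_default),
              new_guard(true, fp32_default, fp32_default));
  return 0;
}
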
21 changes: 11 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -51,15 +51,15 @@ using mkldnn::forward_inference;

inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
const std::vector<NDArray> &aux_states,
- const BatchNormParam &param, bool is_train) {
+ const BatchNormParam &param, bool is_train_and_not_global_stats) {
unsigned flags = 0U;
if (in_data.size() == 3U) {
flags |= use_scale_shift;
}

// aux_states[0]: inMean
// aux_states[1]: inVariance
- if (aux_states.size() == 2U && !is_train) {
+ if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
flags |= use_global_stats;
}
return flags;
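
Since the renamed parameter folds two conditions into one, a compact way to read it (standalone sketch, helper name made up for illustration): the use_global_stats flag, and with it the moving mean/variance, is selected exactly when `ctx.is_train && !param.use_global_stats` is false, i.e. at inference time or whenever the layer is configured with use_global_stats=true. The call sites further down pass exactly that expression.

#include <cstdio>

// Hypothetical helper mirroring the expression passed at the call sites:
// moving statistics are used whenever (is_train && !use_global_stats) is false.
static bool uses_moving_stats(bool is_train, bool use_global_stats) {
  bool is_train_and_not_global_stats = is_train && !use_global_stats;
  return !is_train_and_not_global_stats;
}

int main() {
  const bool values[2] = {false, true};
  for (bool is_train : values) {
    for (bool ugs : values) {
      std::printf("is_train=%d use_global_stats=%d -> moving stats: %d\n",
                  is_train, ugs, uses_moving_stats(is_train, ugs));
    }
  }
  return 0;
}
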
@@ -107,13 +107,13 @@ class MKLDNNBNForward {
std::shared_ptr<const mkldnn::memory> mean_m;
std::shared_ptr<const mkldnn::memory> var_m;
std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
- bool is_train;
+ bool is_train_and_not_global_stats;
t_bn_f_pdesc pd;

public:
- MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train): pd(_pd) {
+ MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
- this->is_train = is_train;
+ this->is_train_and_not_global_stats = is_train_and_not_global_stats;
}

const mkldnn::memory &GetWeight() const {
@@ -161,7 +161,7 @@ class MKLDNNBNForward {
}

if (fwd == nullptr) {
- if (!is_train)
+ if (!is_train_and_not_global_stats)
fwd.reset(new mkldnn::batch_normalization_forward(
pd, *data_m, mkldnn::primitive::at(*mean_m),
mkldnn::primitive::at(*var_m), *weight_m, *out_m));
@@ -194,13 +194,14 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
#endif
MKLDNNBNSignature key(param);
key.AddSign(ctx.is_train);
+ key.AddSign(param.use_global_stats);
key.AddSign(in_data);

auto it = fwds.find(key);
if (it == fwds.end()) {
auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
(DType) param.eps, flags);
- MKLDNNBNForward fwd(fwd_pd, ctx.is_train);
+ MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
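
The easy-to-miss line in this hunk is the extra key.AddSign(param.use_global_stats). Without it, a cached forward primitive built to compute batch statistics could be handed to a call that wants the moving (global) statistics, or vice versa, since the rest of the signature is identical. A minimal standalone sketch of the idea (std::map stands in for MXNet's primitive cache; all names are illustrative):

#include <cstdio>
#include <map>
#include <string>
#include <tuple>

// Illustrative cache key: shape plus both booleans, mirroring the signature after this PR.
using BNKey = std::tuple<std::string /*shape*/, bool /*is_train*/, bool /*use_global_stats*/>;

int main() {
  std::map<BNKey, int> cache;  // int stands in for a cached MKLDNN primitive
  cache[std::make_tuple(std::string("2x3x2x2"), true, false)] = 1;  // batch-statistics primitive
  cache[std::make_tuple(std::string("2x3x2x2"), true, true)] = 2;   // global-statistics primitive
  // Two distinct entries: calls differing only in use_global_stats no longer collide.
  std::printf("cache entries: %zu\n", cache.size());
  return 0;
}
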
@@ -213,7 +214,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
const std::vector<NDArray> &out_data,
const std::vector<NDArray> &aux_states) {
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
- unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
+ unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
const NDArray &data = in_data[batchnorm::kData];

auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
@@ -253,7 +254,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
}
}

- if (!ctx.is_train) {
+ if (!ctx.is_train || param.use_global_stats) {
DType* omean = out_data[batchnorm::kMean].data().dptr<DType>();
DType* ovar = out_data[batchnorm::kVar].data().dptr<DType>();
DType* inmean = aux_states[batchnorm::kMovingMean].data().dptr<DType>();
Expand Down Expand Up @@ -378,7 +379,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(out_data.size(), 3U);
CHECK_EQ(in_grad.size(), 3U);
- unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
+ unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);

const NDArray &data = in_data[batchnorm::kData];
const NDArray &diff = out_grad[batchnorm::kOut];
4 changes: 2 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -1626,7 +1626,7 @@ def test_bilinear_upsampling():
@with_seed()
def test_batchnorm_training():
def check_batchnorm_training(stype):
- for shape in [(2, 3), (2, 3, 2, 2)]:
+ for shape in [(2, 3), (2, 3, 2, 2), (2, 8, 2, 2)]:
data_tmp = np.random.normal(-0.1, 0.1, size=shape)
s = shape[1],
gamma = np.ones(s)
@@ -1821,7 +1821,7 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol)

for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
- for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]:
+ for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
for axis in range(len(shape)):
for cudnn_off in [False, True]:
for output_mean_var in [False, True]: