From ac58ff4bd8ffc2e273803c79cafcff38feb765c4 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Fri, 9 Aug 2019 14:37:29 +0800 Subject: [PATCH 01/11] update support MKLDNN BN conditions --- src/operator/nn/batch_norm.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 6382d46d272d..51e70c94d4ab 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -384,7 +384,6 @@ static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &p mxnet::TShape shape = input.shape(); return SupportMKLDNN(input) && shape.ndim() == 4 && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS - && shape[param.axis] % 8 == 0 && !mxnet::op::batchnorm::disable_mkl; } @@ -396,7 +395,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs, CHECK_EQ(inputs.size(), 5U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); // MKLDNN batchnorm only works well on the special MKLDNN layout. - if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) { + if (SupportMKLDNNBN(inputs[0], param)) { std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); @@ -420,8 +419,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, mxnet::TShape shape = inputs[0].shape(); // MKLDNN batchnorm only works well on the special MKLDNN layout. - if (SupportMKLDNNBN(inputs[0], param) - && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { + if (SupportMKLDNNBN(inputs[0], param)) { std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); From 008e07e0c107927f643ef3981623de909bd5deaa Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 14 Aug 2019 10:47:11 +0800 Subject: [PATCH 02/11] add bn test case for channel size = 8 --- tests/python/unittest/test_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c360db9f01a5..466cee823029 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1626,7 +1626,7 @@ def test_bilinear_upsampling(): @with_seed() def test_batchnorm_training(): def check_batchnorm_training(stype): - for shape in [(2, 3), (2, 3, 2, 2)]: + for shape in [(2, 3), (2, 3, 2, 2), (2, 8, 2, 2)]: data_tmp = np.random.normal(-0.1, 0.1, size=shape) s = shape[1], gamma = np.ones(s) @@ -1821,7 +1821,7 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: - for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: + for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): for cudnn_off in [False, True]: for output_mean_var in [False, True]: From a83396795e45f0f66633dfb8d69ed81af5618c42 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 14 Aug 2019 10:48:23 +0800 Subject: [PATCH 03/11] fix bn gradient with use_global_stats --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index f294153ecc24..5ac131f43222 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -253,7 +253,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, } } - if (!ctx.is_train) { + if (!ctx.is_train || param.use_global_stats) { DType* omean = out_data[batchnorm::kMean].data().dptr(); DType* ovar = out_data[batchnorm::kVar].data().dptr(); DType* inmean = aux_states[batchnorm::kMovingMean].data().dptr(); From 30863475a5f14869c3575002284f89cb44800e50 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 14 Aug 2019 10:50:29 +0800 Subject: [PATCH 04/11] rm useless comments --- src/operator/nn/batch_norm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 51e70c94d4ab..3214e3b9b9ac 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -394,7 +394,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { CHECK_EQ(inputs.size(), 5U); const BatchNormParam ¶m = nnvm::get(attrs.parsed); - // MKLDNN batchnorm only works well on the special MKLDNN layout. + if (SupportMKLDNNBN(inputs[0], param)) { std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); @@ -418,7 +418,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, const BatchNormParam ¶m = nnvm::get(attrs.parsed); mxnet::TShape shape = inputs[0].shape(); - // MKLDNN batchnorm only works well on the special MKLDNN layout. + if (SupportMKLDNNBN(inputs[0], param)) { std::vector out_grad(1); std::vector out_data(3); From 201d869d92dd57eec9afd2e0f84c5291013e4060 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Tue, 20 Aug 2019 10:36:44 +0800 Subject: [PATCH 05/11] fix mkldnn bn output when use_global_stats --- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 5ac131f43222..cb3a7566c078 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -51,7 +51,7 @@ using mkldnn::forward_inference; inline static unsigned _GetFlags(const std::vector &in_data, const std::vector &aux_states, - const BatchNormParam ¶m, bool is_train) { + const BatchNormParam ¶m, bool is_train_and_not_global_stats) { unsigned flags = 0U; if (in_data.size() == 3U) { flags |= use_scale_shift; @@ -59,7 +59,7 @@ inline static unsigned _GetFlags(const std::vector &in_data, // aux_states[0]: inMean // aux_states[1]: inVariance - if (aux_states.size() == 2U && !is_train) { + if (aux_states.size() == 2U && !is_train_and_not_global_stats) { flags |= use_global_stats; } return flags; @@ -107,13 +107,13 @@ class MKLDNNBNForward { std::shared_ptr mean_m; std::shared_ptr var_m; std::shared_ptr fwd; - bool is_train; + bool is_train_and_not_global_stats; t_bn_f_pdesc pd; public: - MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train): pd(_pd) { + MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) { weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc())); - this->is_train = is_train; + this->is_train_and_not_global_stats = is_train_and_not_global_stats; } const mkldnn::memory &GetWeight() const { @@ -161,7 +161,7 @@ class MKLDNNBNForward { } if (fwd == nullptr) { - if (!is_train) + if (!is_train_and_not_global_stats) fwd.reset(new mkldnn::batch_normalization_forward( pd, *data_m, mkldnn::primitive::at(*mean_m), mkldnn::primitive::at(*var_m), *weight_m, *out_m)); @@ -194,13 +194,14 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param, #endif MKLDNNBNSignature key(param); key.AddSign(ctx.is_train); + key.AddSign(param.use_global_stats); key.AddSign(in_data); auto it = fwds.find(key); if (it == fwds.end()) { auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train, (DType) param.eps, flags); - MKLDNNBNForward fwd(fwd_pd, ctx.is_train); + MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats); it = AddToCache(&fwds, key, fwd); } return it->second; @@ -213,7 +214,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &out_data, const std::vector &aux_states) { TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train); + unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; auto &fwd = GetBNForward(param, ctx, data, flags); @@ -378,7 +379,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam ¶m, CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); CHECK_EQ(in_grad.size(), 3U); - unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train); + unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; const NDArray &diff = out_grad[batchnorm::kOut]; From f0a5a7fefa6eec36345bafdc6b2304981a64aee7 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Tue, 20 Aug 2019 10:53:57 +0800 Subject: [PATCH 06/11] fix lint --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index cb3a7566c078..2d2bf2c64596 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -214,7 +214,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &out_data, const std::vector &aux_states) { TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); + unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; auto &fwd = GetBNForward(param, ctx, data, flags); From 2c27914be7895e3b8eb34f593baf78e745dffb0c Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 21 Aug 2019 15:39:03 +0800 Subject: [PATCH 07/11] retrigger ci From 1307f62fbf71e3f3b89e8b1d5e4e8f4ae4ee2b59 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Thu, 22 Aug 2019 09:18:39 +0800 Subject: [PATCH 08/11] retrigger ci again From d894d3e9e568267f14c9223c073ac397f479592f Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Thu, 22 Aug 2019 12:35:06 +0800 Subject: [PATCH 09/11] retrigger ci again 2 From 8d69ee4ac17934fcd554b6235ff614c0e82d6f72 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Thu, 22 Aug 2019 13:19:32 +0800 Subject: [PATCH 10/11] retrigger ci again 3 From 84870410f48bd43803a883626f6f031b4cd32935 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Fri, 23 Aug 2019 08:28:24 +0800 Subject: [PATCH 11/11] retrigger ci again 4