From fac3b42f28e6dc96c0046092c190f1261e0f1039 Mon Sep 17 00:00:00 2001
From: wkcn
Date: Fri, 5 Jun 2020 19:17:54 +0800
Subject: [PATCH 01/20] fix batch norm when fix_gamma is True

---
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h |  2 +-
 tests/python/unittest/test_operator.py       | 35 +++++++++++++-------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 94a1572b8db3..f562f38d698c 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -228,7 +228,7 @@ class CuDNNBatchNormOp {
           &a,
           &b,
           &a,
-          req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
+          req[cudnnbatchnorm::kGamma] == kAddTo ? &b_add : &b,
           io_desc_,
           x.dptr_,
           io_desc_,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index bdddf626b797..0ab2190a36f0 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1842,8 +1842,10 @@ def test_batchnorm():
     momentum = 0.9
     epsilon = 1e-5
 
-    def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
-        print(str((op, shape, axis, cudnn_off)))
+    def _test_batchnorm_impl(op, shape, axis, fix_gamma,
+                             cudnn_off, output_mean_var):
+        print(str((op, shape, axis, fix_gamma,
+                   cudnn_off, output_mean_var)))
 
         kwargs = dict(output_mean_var=output_mean_var)
         if op == mx.nd.contrib.SyncBatchNorm:
@@ -1857,8 +1859,11 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
             kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
         nch = shape[axis]
 
-        bn_gamma = mx.nd.random.uniform(shape=(nch,))
-        bn_gamma.attach_grad()
+        if not fix_gamma:
+            bn_gamma = mx.nd.random.uniform(shape=(nch,))
+            bn_gamma.attach_grad()
+        else:
+            bn_gamma = mx.nd.ones(shape=(nch,))
 
         bn_beta = mx.nd.random.uniform(shape=(nch,))
         bn_beta.attach_grad()
@@ -1879,7 +1884,7 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
             output = op(data, bn_gamma, bn_beta,
                         bn_running_mean, bn_running_var,
                         momentum=momentum, eps=epsilon,
-                        fix_gamma=False, **kwargs)
+                        fix_gamma=fix_gamma, **kwargs)
             if output_mean_var:
                 output, output_mean, output_std = output
             output.backward(ograd)
@@ -1945,18 +1950,24 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
         assert_almost_equal(data.grad.asnumpy(),
                             dX.asnumpy(), atol=atol, rtol=rtol)
 
-        assert_almost_equal(
-            bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol)
+        if not fix_gamma:
+            assert_almost_equal(
+                bn_gamma.grad.asnumpy(), dW.asnumpy(),
+                atol=atol, rtol=rtol)
+        else:
+            assert((bn_gamma.asnumpy() == 1).all())
         assert_almost_equal(
             bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol)
 
     for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
-        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
+        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4),
+                      (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
             for axis in range(len(shape)):
-                for cudnn_off in [False, True]:
-                    for output_mean_var in [False, True]:
-                        _test_batchnorm_impl(op, shape, axis,
-                                             cudnn_off, output_mean_var)
+                for fix_gamma in [False, True]:
+                    for cudnn_off in [False, True]:
+                        for output_mean_var in [False, True]:
+                            _test_batchnorm_impl(op, shape, axis, fix_gamma,
+                                                 cudnn_off, output_mean_var)
 
 
 @with_seed()

From 34c6a3b4f04cc5b0c317a0695653adf03218b3a3 Mon Sep 17 00:00:00 2001
From: wkcn
Date: Sat, 6 Jun 2020 10:15:37 +0800
Subject: [PATCH 02/20] support gradient accumulation for batch norm

---
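[Editor's note, not part of the patch series] PATCH 02 and its follow-ups add kAddTo (gradient accumulation) support to the BatchNorm backward pass. The sketch below is a minimal illustration of the user-facing behaviour this enables, assuming the standard MXNet 1.x NDArray/autograd API; the shapes and values are arbitrary assumptions. Attaching gradients with grad_req='add' asks every backward pass to accumulate into the existing .grad buffers instead of overwriting them, which is why the operator's backward must honour kAddTo for the data, gamma and beta gradients.

    # Illustrative sketch only; not taken from the patch.
    import mxnet as mx

    data = mx.nd.random.uniform(shape=(4, 3, 8, 8))
    gamma, beta = mx.nd.ones((3,)), mx.nd.zeros((3,))
    running_mean, running_var = mx.nd.zeros((3,)), mx.nd.ones((3,))

    # grad_req='add' requests accumulation (kAddTo) instead of overwrite (kWriteTo).
    for arr in (data, gamma, beta):
        arr.attach_grad(grad_req='add')

    for _ in range(2):  # two backward passes; gradients sum across them
        with mx.autograd.record():
            out = mx.nd.BatchNorm(data, gamma, beta, running_mean, running_var,
                                  fix_gamma=False, momentum=0.9, eps=1e-5)
        out.backward(mx.nd.ones_like(out))
    # data.grad, gamma.grad and beta.grad now hold the sum over both passes;
    # the caller is responsible for zeroing them between optimizer steps.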
src/operator/nn/batch_norm-inl.h | 11 +++ src/operator/nn/batch_norm.cc | 84 ++++++++++++++----- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 30 +++++-- 3 files changed, 98 insertions(+), 27 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 17a16db5adcd..77ef666f7e04 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -259,6 +259,17 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, const std::vector &outputs) { CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 3U); + + // check req + bool req_write_existed = false, req_addto_existed = false; + for (const OpReqType &r : req) { + if (IsBNWriting(r)) req_write_existed = true; + else if (r == kAddTo) req_addto_existed = true; + } + CHECK_EQ(req_write_existed && req_addto_existed, true) \ + << "BatchNorm does not support `grad_req` of two inputs \ +are `write` and `add` simultaneously"; + std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index b865269fc6f5..b3097c2870f2 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3 &in_data, } } +template +static inline void ForEachFast(const BNTensor3 &in_data, + const BNTensor3 &in_data2, + const BNTensor3 &out_data, + const size_t channel, + OnData onData) { + const size_t num = in_data.OuterSize(); + const size_t matrixSize = in_data.InnerSize(); + const size_t skipLength = in_data.SkipLengthToNextSameChannelData(); + const size_t startOffset = in_data.StartOffset(channel); + + DType1 *data = in_data.dptr_ + startOffset; + DType2 *data2 = in_data2.dptr_ + startOffset; + DType3 *odata = out_data.dptr_ + startOffset; + + for (size_t outer = 0; outer < num; ++outer) { + for (size_t i = 0; i < matrixSize; ++i) { + onData(data++, data2++, odata++); + } + data += skipLength; + data2 += skipLength; + odata += skipLength; + } +} + } // namespace batchnorm /*! \brief Forward CPU */ @@ -263,7 +288,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, dotp += (*thisInputData - mean) * (*gradOut_data); }); - if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) { // if there's a grad input + if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) { // if there's a grad input if (is_train_and_not_global_stats) { // when in training mode // Q(X) = X - E[x] ; i.e. 
input centered to zero mean @@ -272,44 +297,59 @@ void BatchNormBackwardImpl(mshadow::Stream *, // projection of gradOutput on to output scaled by std const AccReal k = dotp * invstd * invstd / itemCount; - ForEachFast(inputData, gradIn, static_cast(channel), - [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { - *gradIn_data = (*inputDataPtr - mean) * k; - }); - const AccReal iw = invstd * w; const AccReal gradMean = sumGradOut / itemCount; - ForEachFast(gradOut, gradIn, static_cast(channel), - [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) { - *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw; - }); + if (req[batchnorm::kData != kAddTo) { + ForEachFast(inputData, gradIn, static_cast(channel), + [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { + *gradIn_data = (*inputDataPtr - mean) * k; + }); + + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw; + }); + } else { + ForEachFast(inputData, gradOut, gradIn, static_cast(channel), + [&mean, &k, iw, gradMean](const DType *inputDataPtr, + const DType *gradOut_data, + DType *gradIn_data) { + DType normal_val = (*inputDataPtr - mean) * k; + *gradIn_data += (*gradOut_data - gradMean - + normal_val) * iw; + }); + } } else { // when in evaluation mode // Q(X) = X - running_mean ; i.e. input centered to zero mean // Y = Q(X) / running_std ; i.e. BN output before weight and bias // dL/dX = w / running_std const AccReal iw = invstd * w; - ForEachFast(gradOut, gradIn, static_cast(channel), - [iw](const DType *gradOut_data, DType *gradIn_data) { - *gradIn_data = *gradOut_data * iw; - }); + if (req[batchnorm::kData != kAddTo) { + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data = *gradOut_data * iw; + }); + } else { + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data += *gradOut_data * iw; + }); + } } } // May want to make this a param eventually const AccReal scale = 1.0f; - if (IsBNWriting(req[batchnorm::kGamma])) { - if (!param_.fix_gamma) { - gradWeightData[channel] = scale * dotp * invstd; - } else { + if (!param_.fix_gamma) { + KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); + } else { + if (req[batchnorm::kGamma] != kNullOp) gradWeightData[channel] = AccReal(0); - } } - if (IsBNWriting(req[batchnorm::kBeta])) { - gradBiasData[channel] = scale * sumGradOut; - } + KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); } } diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index d7401340f20f..bdbc47ffa03b 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -300,6 +300,7 @@ template void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs, bool fuse_relu) { + CHECK_NE(req[batchnorm::kData], kAddTo) << "MKLDNN BatchNorm does not support `data.grad_req = add`"; if (fuse_relu) { CHECK_EQ(inputs.size(), 9U); } else { @@ -412,17 +413,36 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, // copy data from gradw_mem to in_grad[1] and in_grad[2] DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); - DType 
*w_grad_1 = in_grad[1].data().dptr(); - DType *w_grad_2 = in_grad[2].data().dptr(); + DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr(); + DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr(); + // the gradient of gamma if (!param.fix_gamma) { - memcpy(w_grad_1, gw_buf, copy_size); - memcpy(w_grad_2, &gw_buf[channels_], copy_size); + if (req[batchnorm::kGamma] != kNullOp) { + if (req[batchnorm::kGamma] != kAddTo) { + memcpy(w_grad_1, gw_buf, copy_size); + } else { + for (int i = 0; i < channels_; i++) { + w_grad_1[i] += gw_buf[i]; + } + } + } } else { for (int i = 0; i < channels_; i++) { (in_grad[1].data().dptr())[i] = 0.0f; } - memcpy(w_grad_2, &gw_buf[channels_], copy_size); + } + + // the gradient of beta + if (req[batchnorm::kBeta] != kNullOp) { + if (req[batchnorm::kBeta] != kAddTo) { + memcpy(w_grad_2, &gw_buf[channels_], copy_size); + } else { + DType *grad_beta = &gw_buf[channels_]; + for (int i = 0; i < channels_; i++) { + w_grad_2[i] += grad_beta[i]; + } + } } } else { LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ..."; From 11c8e26f2057a758275d11411d21208e03f566b2 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 10:59:43 +0800 Subject: [PATCH 03/20] mkldnn batchnorm support grad add --- src/operator/nn/batch_norm.cc | 4 ++-- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index b3097c2870f2..f46ab6b9e945 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -299,7 +299,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, const AccReal k = dotp * invstd * invstd / itemCount; const AccReal iw = invstd * w; const AccReal gradMean = sumGradOut / itemCount; - if (req[batchnorm::kData != kAddTo) { + if (req[batchnorm::kData] != kAddTo) { ForEachFast(inputData, gradIn, static_cast(channel), [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { *gradIn_data = (*inputDataPtr - mean) * k; @@ -325,7 +325,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, // Y = Q(X) / running_std ; i.e. 
BN output before weight and bias // dL/dX = w / running_std const AccReal iw = invstd * w; - if (req[batchnorm::kData != kAddTo) { + if (req[batchnorm::kData] != kAddTo) { ForEachFast(gradOut, gradIn, static_cast(channel), [iw](const DType *gradOut_data, DType *gradIn_data) { *gradIn_data = *gradOut_data * iw; diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index bdbc47ffa03b..fdc4f9c946c4 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -300,7 +300,6 @@ template void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs, bool fuse_relu) { - CHECK_NE(req[batchnorm::kData], kAddTo) << "MKLDNN BatchNorm does not support `data.grad_req = add`"; if (fuse_relu) { CHECK_EQ(inputs.size(), 9U); } else { @@ -348,7 +347,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, else if (diff.IsDefaultData()) diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc()); auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); - auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_desc()); + auto gradi_mem = CreateMKLDNNMem(const_cast(gradIn), + bwd.GetDataPd().diff_src_desc(), req[batchnorm::kData]); if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; @@ -369,7 +369,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } mkldnn_args_map_t net_args; net_args[MKLDNN_ARG_SRC] = *data_mem; - net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem; + net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second; net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight(); net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw(); net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem; @@ -408,6 +408,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); MKLDNNStream::Get()->Submit(); } From 830ba0dbced293edc8f93c4ec6b8e9b77f490d29 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 10:59:54 +0800 Subject: [PATCH 04/20] unittest for bn --- tests/python/unittest/test_operator.py | 42 +++++++++++++++++++------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0ab2190a36f0..1e500a1a17ad 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1843,8 +1843,10 @@ def test_batchnorm(): epsilon = 1e-5 def _test_batchnorm_impl(op, shape, axis, fix_gamma, + grad_req, cudnn_off, output_mean_var): print(str((op, shape, axis, fix_gamma, + grad_req, cudnn_off, output_mean_var))) kwargs = dict(output_mean_var=output_mean_var) @@ -1861,12 +1863,12 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, if not fix_gamma: bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad() + bn_gamma.attach_grad(grad_req=grad_req) else: bn_gamma = mx.nd.ones(shape=(nch,)) bn_beta = mx.nd.random.uniform(shape=(nch,)) - bn_beta.attach_grad() + bn_beta.attach_grad(grad_req=grad_req) bn_running_mean = mx.nd.zeros(nch) bn_running_var = mx.nd.ones(nch) @@ -1876,12 +1878,19 @@ def 
_test_batchnorm_impl(op, shape, axis, fix_gamma, num_iters = 10 expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] + other_data = mx.nd.zeros(shape=shape) + other_data.attach_grad(grad_req=grad_req) + adX, adW, adb = 0, 0, 0 for _ in range(num_iters): data = mx.nd.random.uniform(shape=shape) - data.attach_grad() + data.attach_grad(grad_req=grad_req) + if grad_req == 'add': + mixed_data = data + other_data + else: + mixed_data = data ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): - output = op(data, bn_gamma, bn_beta, + output = op(mixed_data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, momentum=momentum, eps=epsilon, fix_gamma=fix_gamma, **kwargs) @@ -1924,6 +1933,12 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) + if grad_req == 'add': + adX += dX + adW += dW + adb += db + else: + adX, adW, adb = dX, dW, db atol = 1e-2 rtol = 1e-2 @@ -1950,25 +1965,30 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, assert_almost_equal(data.grad.asnumpy(), dX.asnumpy(), atol=atol, rtol=rtol) + + if grad_req == 'add': + assert_almost_equal(other_data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) if not fix_gamma: assert_almost_equal( - bn_gamma.grad.asnumpy(), dW.asnumpy(), + bn_gamma.grad.asnumpy(), adW.asnumpy(), atol=atol, rtol=rtol) else: assert((bn_gamma.asnumpy() == 1).all()) assert_almost_equal( - bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): for fix_gamma in [False, True]: - for cudnn_off in [False, True]: - for output_mean_var in [False, True]: - _test_batchnorm_impl(op, shape, axis, fix_gamma, - cudnn_off, output_mean_var) - + for grad_req in ['write', 'add']: + for cudnn_off in [False, True]: + for output_mean_var in [False, True]: + _test_batchnorm_impl(op, shape, axis, fix_gamma, + grad_req, + cudnn_off, output_mean_var) @with_seed() def test_groupnorm(): From a9c91d20dc02aae8e6b43738d3b542a7e98cc664 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 14:27:34 +0800 Subject: [PATCH 05/20] fix bn arg --- src/operator/nn/batch_norm-inl.h | 2 +- tests/python/unittest/test_operator.py | 27 +++++++++++--------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 77ef666f7e04..d6eafaf90abb 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -266,7 +266,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, if (IsBNWriting(r)) req_write_existed = true; else if (r == kAddTo) req_addto_existed = true; } - CHECK_EQ(req_write_existed && req_addto_existed, true) \ + CHECK_EQ(req_write_existed && req_addto_existed, false) \ << "BatchNorm does not support `grad_req` of two inputs \ are `write` and `add` simultaneously"; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1e500a1a17ad..bd82347f40ea 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1878,19 +1878,16 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, num_iters = 10 expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] - 
other_data = mx.nd.zeros(shape=shape) - other_data.attach_grad(grad_req=grad_req) + data = mx.nd.random.uniform(shape=shape) + data.attach_grad(grad_req=grad_req) adX, adW, adb = 0, 0, 0 for _ in range(num_iters): - data = mx.nd.random.uniform(shape=shape) - data.attach_grad(grad_req=grad_req) - if grad_req == 'add': - mixed_data = data + other_data - else: - mixed_data = data + if grad_req != 'add': + data = mx.nd.random.uniform(shape=shape) + data.attach_grad(grad_req=grad_req) ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): - output = op(mixed_data, bn_gamma, bn_beta, + output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, momentum=momentum, eps=epsilon, fix_gamma=fix_gamma, **kwargs) @@ -1940,8 +1937,10 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, else: adX, adW, adb = dX, dW, db - atol = 1e-2 - rtol = 1e-2 + if grad_req == 'add': + atol, rtol = 5e-1, 5e-1 + else: + atol, rtol = 1e-2, 1e-2 if output_mean_var: assert_almost_equal(output_mean.asnumpy(), @@ -1964,11 +1963,7 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, ), running_var.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(data.grad.asnumpy(), - dX.asnumpy(), atol=atol, rtol=rtol) - - if grad_req == 'add': - assert_almost_equal(other_data.grad.asnumpy(), - adX.asnumpy(), atol=atol, rtol=rtol) + adX.asnumpy(), atol=atol, rtol=rtol) if not fix_gamma: assert_almost_equal( bn_gamma.grad.asnumpy(), adW.asnumpy(), From 85946423f0184653577f5924f507e10446c6890c Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 14:36:23 +0800 Subject: [PATCH 06/20] fix lint --- src/operator/nn/batch_norm-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index d6eafaf90abb..9656784eeb4c 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -267,8 +267,8 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, else if (r == kAddTo) req_addto_existed = true; } CHECK_EQ(req_write_existed && req_addto_existed, false) \ - << "BatchNorm does not support `grad_req` of two inputs \ -are `write` and `add` simultaneously"; + << "BatchNorm does not support `grad_req` of two inputs" + "are `write` and `add` simultaneously"; std::vector out_grad(1); std::vector out_data(3); From f9995b096e68632963054fe962565bf5eeb74f7b Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 17:06:23 +0800 Subject: [PATCH 07/20] fix mkldnn --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index fdc4f9c946c4..d869542a0346 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -348,7 +348,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc()); auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); auto gradi_mem = CreateMKLDNNMem(const_cast(gradIn), - bwd.GetDataPd().diff_src_desc(), req[batchnorm::kData]); + bwd.pd.diff_src_desc(), req[batchnorm::kData]); if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; @@ -403,12 +403,12 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, net_args[MKLDNN_ARG_MEAN] = 
*(out_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = var_mem; MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); MKLDNNStream::Get()->Submit(); } else { net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - CommitOutput(gradIn, gradi_mem); MKLDNNStream::Get()->Submit(); } From d38044345f796a620c0e9f4ff30a318a939d9f4f Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 17:25:35 +0800 Subject: [PATCH 08/20] fix mkldnn bn --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 8 +++----- tests/python/unittest/test_operator.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index d869542a0346..da4fd97e82da 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -402,15 +402,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = var_mem; - MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - CommitOutput(gradIn, gradi_mem); - MKLDNNStream::Get()->Submit(); } else { net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); - MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - MKLDNNStream::Get()->Submit(); } + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); + MKLDNNStream::Get()->Submit(); // copy data from gradw_mem to in_grad[1] and in_grad[2] DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index bd82347f40ea..8cf657a481e7 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1938,7 +1938,7 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, adX, adW, adb = dX, dW, db if grad_req == 'add': - atol, rtol = 5e-1, 5e-1 + atol, rtol = 5e-2, 5e-2 else: atol, rtol = 1e-2, 1e-2 From eaeae2101c0d339fdcec7efadf4664e4b1c2bc86 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 17:32:37 +0800 Subject: [PATCH 09/20] fix grad when fixing gamma --- src/operator/nn/batch_norm.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index f46ab6b9e945..2b2b11e553fc 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -345,8 +345,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, if (!param_.fix_gamma) { KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); } else { - if (req[batchnorm::kGamma] != kNullOp) - gradWeightData[channel] = AccReal(0); + gradWeightData[channel] = AccReal(0); } KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); From 25924c956475104111d54b7dcb3df1ae396600c0 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 20:58:31 +0800 Subject: [PATCH 10/20] fix naive gpu bn --- src/operator/nn/batch_norm.cu | 66 ++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 84d1cecd93d9..2f990eb7b617 100644 --- 
a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -34,6 +34,9 @@ #define FIX_GAMMA_FLAG 8 #define IS_TRAINING_FLAG 16 #define USE_GLOBAL_STATS_FLAG 32 +#define ADDTO_DATA_FLAG (1 << 6) +#define ADDTO_GAMMA_FLAG (1 << 7) +#define ADDTO_BETA_FLAG (1 << 8) #if MXNET_USE_CUDNN == 1 #include "./cudnn/cudnn_batch_norm-inl.h" @@ -361,33 +364,58 @@ static __global__ void BatchNormalizationBackwardKernel( * momentum + localVariance * (AccReal(1) - momentum); } - if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { - for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { - const DType gradOut = gradOutput.get_ref(batch, plane, x); - if (is_train_and_not_global_stats) { - const DType inp = input.get_ref(batch, plane, x); - const AccReal proj = (inp - mean) * projScale; - gradInput.get_ref(batch, plane, x) = - ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); + if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) { + const bool grad_write = flags & WRITE_DATA_FLAG; + if (grad_write) { + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + if (is_train_and_not_global_stats) { + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) = + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); + } else { + gradInput.get_ref(batch, plane, x) = ScalarConvert::to( + gradOut * gradScale); + } + } + } + } else { + // grad addto + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + if (is_train_and_not_global_stats) { + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) += + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); + } else { + gradInput.get_ref(batch, plane, x) += ScalarConvert::to( + gradOut * gradScale); + } } } } } - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { if ((flags & FIX_GAMMA_FLAG) == 0) { - tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + if (flags & WRITE_GAMMA_FLAG) + tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + else + tensors.gradWeight[plane] += ScalarConvert::to(dotP * invstd); } else { tensors.gradWeight[plane] = DType(0); } } - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { - tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { + if (flags & WRITE_BETA_FLAG) + tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + else + tensors.gradBias[plane] += ScalarConvert::to(gradOutputSum); } } @@ -579,12 +607,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx, flags |= params.use_global_stats ? 
USE_GLOBAL_STATS_FLAG : 0; if (IsBNWriting(req[batchnorm::kData])) { flags |= WRITE_DATA_FLAG; + } else if (req[batchnorm::kData] == kAddTo) { + flags |= ADDTO_DATA_FLAG; } if (IsBNWriting(req[batchnorm::kGamma])) { flags |= WRITE_GAMMA_FLAG; + } else if (req[batchnorm::kGamma] == kAddTo) { + flags |= ADDTO_GAMMA_FLAG; } if (IsBNWriting(req[batchnorm::kBeta])) { flags |= WRITE_BETA_FLAG; + } else if (req[batchnorm::kBeta] == kAddTo) { + flags |= ADDTO_BETA_FLAG; } return flags; } From e3bb53c44ef91d6b4e35daba75f74f54dd4dd5d4 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 21:30:22 +0800 Subject: [PATCH 11/20] fix lint --- src/operator/nn/batch_norm.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 2f990eb7b617..0875f05e669d 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -400,7 +400,8 @@ static __global__ void BatchNormalizationBackwardKernel( } } - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && + (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { if ((flags & FIX_GAMMA_FLAG) == 0) { if (flags & WRITE_GAMMA_FLAG) tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); @@ -411,7 +412,8 @@ static __global__ void BatchNormalizationBackwardKernel( } } - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && + (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { if (flags & WRITE_BETA_FLAG) tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); else From 9c0056723cb1fa897e4b16271176d5ee92a3b6cc Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 7 Jun 2020 00:11:21 +0800 Subject: [PATCH 12/20] invoke mkldnn and cudnn batchnorm when axis != 1 --- src/operator/nn/batch_norm.cc | 3 +- src/operator/nn/batch_norm.cu | 6 +-- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 32 ++++++++++---- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 44 ++++++++++++++++--- 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 2b2b11e553fc..833d5b1f311f 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -419,8 +419,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam ¶m) { mxnet::TShape shape = input.shape(); - return SupportMKLDNN(input) && shape.ndim() == 4 - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS + return SupportMKLDNN(input) && !mxnet::op::batchnorm::disable_mkl; } diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 0875f05e669d..c7e991f98d18 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -698,8 +698,7 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 - if (!param.use_global_stats && !param.cudnn_off - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { + if (!param.use_global_stats && !param.cudnn_off) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); }) @@ -727,8 +726,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, param.axis = 
mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 - if (!param.use_global_stats && !param.cudnn_off - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { + if (!param.use_global_stats && !param.cudnn_off) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Backward(ctx, inputs, req, outputs); }) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index f562f38d698c..71f59babe331 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -251,15 +251,31 @@ class CuDNNBatchNormOp { private: void Init(const TBlob &in_data) { - if (in_data.ndim() == 4) { - for (int i = 0; i < 4; ++i) - shape_[i] = in_data.shape_[i]; + CHECK_GE(param_.axis, 0); + CHECK_LT(param_.axis, in_data.ndim()); + if (param_.axis == 1) { + if (in_data.ndim() == 4) { + for (int i = 0; i < 4; ++i) + shape_[i] = in_data.shape_[i]; + } else { + // when in_data.ndim() != 4 + shape_[0] = in_data.shape_[0]; + shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; + shape_[2] = 1; + shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); + } } else { - // when in_data.ndim() != 4 - shape_[0] = in_data.shape_[0]; - shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; - shape_[2] = 1; - shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); + if (in_data.ndim() == 4 && param_.axis == 1) { + for (int i = 0; i < 4; ++i) + shape_[i] = in_data.shape_[i]; + } else { + // reshape to (N, C, 1, D), C is the `param_.axis` dimension + shape_[0] = static_cast(in_data.shape_.ProdShape(0, param_.axis)); + shape_[1] = in_data.shape_[param_.axis]; + shape_[2] = 1; + shape_[3] = static_cast(in_data.shape_.ProdShape(param_.axis + 1, + static_cast(in_data.ndim()))); + } } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index da4fd97e82da..0a29a6d87de6 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs, bool fuse_relu) { const BatchNormParam ¶m = nnvm::get(attrs.parsed); - const std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); + std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); + + mxnet::TShape shape = inputs[batchnorm::kData].shape(); + const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); + CHECK_LT(real_axis, shape.ndim()); + NDArray out = outputs[batchnorm::kOut]; + if (param.axis != 1 || shape.ndim() != 4) { + // reshape to (N, C, 1, D) + mxnet::TShape new_shape{ + static_cast(shape.ProdShape(0, real_axis)), + shape[real_axis], + 1, + static_cast(shape.ProdShape(real_axis + 1, + static_cast(shape.ndim()))) + }; + in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape); + out = out.Reshape(new_shape); + } + const std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); mkldnn::normalization_flags flags = _GetFlags(in_data, @@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, fuse_relu); const NDArray &data = in_data[batchnorm::kData]; auto &fwd = GetBNForward(param, ctx, data, flags); - const NDArray &out = 
outputs[batchnorm::kOut]; // for output memory auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); @@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, ctx.is_train && !param.use_global_stats, fuse_relu); - const NDArray &data = in_data[batchnorm::kData]; - const NDArray &diff = out_grad[batchnorm::kOut]; - const NDArray &gradIn = in_grad[batchnorm::kData]; + NDArray data = in_data[batchnorm::kData]; + NDArray diff = out_grad[batchnorm::kOut]; + NDArray gradIn = in_grad[batchnorm::kData]; const NDArray &moving_mean = aux_states[batchnorm::kMovingMean]; const NDArray &moving_var = aux_states[batchnorm::kMovingVar]; const NDArray &out_mean = out_data[batchnorm::kMean]; @@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, CHECK(moving_mean.IsDefaultData()); CHECK(moving_var.IsDefaultData()); + mxnet::TShape shape = data.shape(); + const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); + CHECK_LT(real_axis, shape.ndim()); + if (param.axis != 1 || shape.ndim() != 4) { + // reshape to (N, C, 1, D) + mxnet::TShape new_shape{ + static_cast(shape.ProdShape(0, real_axis)), + shape[real_axis], + 1, + static_cast(shape.ProdShape(real_axis + 1, + static_cast(shape.ndim()))) + }; + data = data.Reshape(new_shape); + diff = diff.Reshape(new_shape); + gradIn = gradIn.Reshape(new_shape); + } + auto data_mem = data.GetMKLDNNData(); auto diff_mem = diff.GetMKLDNNData(); // MKLDNN batchnorm should run on special layouts. If one of them isn't, we From 4cab1ed5157547075bbb25fdc9adc2ab6a6f7314 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Tue, 9 Jun 2020 01:41:35 +0800 Subject: [PATCH 13/20] backport 18500 --- src/operator/nn/batch_norm-inl.h | 10 -- src/operator/nn/batch_norm.cc | 4 +- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 15 +- tests/python/unittest/test_numpy_op.py | 153 +++++++++++++++++++ tests/python/unittest/test_operator.py | 106 +++++++------ 5 files changed, 226 insertions(+), 62 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 9656784eeb4c..485b3b33f6a8 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -260,16 +260,6 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 3U); - // check req - bool req_write_existed = false, req_addto_existed = false; - for (const OpReqType &r : req) { - if (IsBNWriting(r)) req_write_existed = true; - else if (r == kAddTo) req_addto_existed = true; - } - CHECK_EQ(req_write_existed && req_addto_existed, false) \ - << "BatchNorm does not support `grad_req` of two inputs" - "are `write` and `add` simultaneously"; - std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 833d5b1f311f..588c4ed9625e 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -345,7 +345,9 @@ void BatchNormBackwardImpl(mshadow::Stream *, if (!param_.fix_gamma) { KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); } else { - gradWeightData[channel] = AccReal(0); + if (IsBNWriting(req[batchnorm::kGamma])) { + gradWeightData[channel] = AccReal(0); + } } KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h 
index 71f59babe331..d39011ce7f3c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -222,13 +222,24 @@ class CuDNNBatchNormOp { if (param_.fix_gamma) gamma = 1.f; + bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) || + (req[cudnnbatchnorm::kBeta] == kAddTo); + if (grad_add_gamma_beta) { + if (IsBNWriting(req[cudnnbatchnorm::kGamma])) { + dgamma = 0.f; + } + if (IsBNWriting(req[cudnnbatchnorm::kBeta])) { + dbeta = 0.f; + } + } + CUDNN_CALL(cudnnBatchNormalizationBackward( s->dnn_handle_, mode, &a, - &b, + req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b, &a, - req[cudnnbatchnorm::kGamma] == kAddTo ? &b_add : &b, + grad_add_gamma_beta ? &b_add : &b, // gamma and beta io_desc_, x.dptr_, io_desc_, diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 45b6a9c7c217..550b6dd42d32 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1596,6 +1596,159 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, transpose_b=transpose_b)) +@with_seed() +@use_np +@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), + (24, 8, 4, 5), (24, 5, 6, 4, 5)]) +@pytest.mark.parametrize('fix_gamma', [False, True]) +@pytest.mark.parametrize('cudnn_off', [False, True]) +@pytest.mark.parametrize('output_mean_var', [False, True]) +def test_npx_batch_norm(shape, fix_gamma, cudnn_off, output_mean_var): + momentum = 0.9 + epsilon = 1e-5 + class TestBatchNorm(HybridBlock): + def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs): + super().__init__() + self.eps = eps + self.fix_gamma = fix_gamma + self.momentum = momentum + self.kwargs = kwargs + def hybrid_forward(self, F, data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var): + op = F.npx.batch_norm + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var, + momentum=self.momentum, eps=self.eps, + fix_gamma=self.fix_gamma, **self.kwargs) + return output + + def _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req): + kwargs = dict(output_mean_var=output_mean_var) + kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) + op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs) + nch = shape[axis] + + if not fix_gamma: + bn_gamma = np.random.uniform(size=(nch,)) + bn_gamma.attach_grad(grad_req=gamma_grad_req) + else: + bn_gamma = np.ones((nch,)) + + bn_beta = np.random.uniform(size=(nch,)) + bn_beta.attach_grad(grad_req=beta_grad_req) + + bn_running_mean = np.zeros(nch) + bn_running_var = np.ones(nch) + + running_mean = np.zeros(nch) + running_var = np.ones(nch) + num_iters = 10 + expand_shape = [1] * len(shape) + expand_shape[axis] = shape[axis] + expand_shape = tuple(expand_shape) + data = np.random.uniform(size=shape) + data.attach_grad(grad_req=data_grad_req) + adX, adW, adb = 0, 0, 0 + is_train = data_grad_req != 'null' or \ + (not fix_gamma and gamma_grad_req != 'null') or \ + beta_grad_req != 'null' + for _ in range(num_iters): + if data_grad_req != 'add': + data = np.random.uniform(size=shape) + data.attach_grad(grad_req=data_grad_req) + ograd = np.random.uniform(size=shape) + with mx.autograd.record(): + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var) + if output_mean_var: + output, output_mean, output_std = output + if is_train: + output.backward(ograd) + mx.nd.waitall() + + assert 0 <= axis < data.ndim + reduce_axis = tuple(i for i in range(data.ndim) if i != 
axis) + assert len(reduce_axis) == data.ndim - 1 + data_mean = data.mean( + axis=reduce_axis, keepdims=True) + data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis, + keepdims=True) + + target_output = (data - data_mean) / \ + np.sqrt(data_var + epsilon) * \ + bn_gamma.reshape(expand_shape) + \ + bn_beta.reshape(expand_shape) + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + W = bn_gamma.reshape(expand_shape) + dnx = ograd * W + xsm = data - data_mean + nd = 1.0 / np.sqrt(data_var + epsilon) + nx = xsm * nd + m = _np.prod(shape) / shape[axis] + dvar = (dnx * xsm).sum(axis=reduce_axis, keepdims=True, + ) * (-0.5) * np.power(nd, 3) + dmean = -nd * dnx.sum(axis=reduce_axis, keepdims=True) - \ + dvar * xsm.mean(axis=reduce_axis, keepdims=True, + ) * 2.0 + dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) + dW = (ograd * nx).sum(axis=reduce_axis) + db = ograd.sum(axis=reduce_axis) + adX = dX if data_grad_req != 'add' else adX + dX + adW = dW if gamma_grad_req != 'add' else adW + dW + adb = db if beta_grad_req != 'add' else adb + db + + atol, rtol = 5e-2, 5e-2 + + if output_mean_var: + assert_almost_equal(output_mean.asnumpy(), + data_mean_flat.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output_std.asnumpy(), + (1.0 / np.sqrt(data_var_flat + + epsilon)).asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + if is_train: + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + if data_grad_req != 'null': + assert_almost_equal(data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) + if not fix_gamma: + if gamma_grad_req != 'null': + assert_almost_equal( + bn_gamma.grad.asnumpy(), adW.asnumpy(), + atol=atol, rtol=rtol) + else: + assert((bn_gamma.asnumpy() == 1).all()) + if beta_grad_req != 'null': + assert_almost_equal( + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) + + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req) + @with_seed() @use_np def test_npx_softmax(): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 8cf657a481e7..c4385880df64 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1838,19 +1838,26 @@ def check_batchnorm_training(stype): @xfail_when_nonstandard_decimal_separator @with_seed() -def test_batchnorm(): +@pytest.mark.parametrize('op_name', ['BatchNorm', 'SyncBatchNorm']) +@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), + (24, 8, 4, 5), (24, 5, 6, 4, 5)]) +@pytest.mark.parametrize('fix_gamma', [False, True]) +@pytest.mark.parametrize('cudnn_off', [False, True]) +@pytest.mark.parametrize('output_mean_var', [False, True]) +def test_batchnorm(op_name, shape, fix_gamma, cudnn_off, output_mean_var): + if op_name == 'BatchNorm': + op = mx.nd.BatchNorm + elif op_name == 'SyncBatchNorm': + op = mx.nd.contrib.SyncBatchNorm 
+ else: + raise ValueError(f'Not supported {op_name}') momentum = 0.9 epsilon = 1e-5 - def _test_batchnorm_impl(op, shape, axis, fix_gamma, - grad_req, - cudnn_off, output_mean_var): - print(str((op, shape, axis, fix_gamma, - grad_req, - cudnn_off, output_mean_var))) - + def _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req): kwargs = dict(output_mean_var=output_mean_var) - if op == mx.nd.contrib.SyncBatchNorm: + if op_name == 'SyncBatchNorm': if axis != 1: return key = str(op) + str(shape) + str(axis) @@ -1863,12 +1870,12 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, if not fix_gamma: bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad(grad_req=grad_req) + bn_gamma.attach_grad(grad_req=gamma_grad_req) else: bn_gamma = mx.nd.ones(shape=(nch,)) bn_beta = mx.nd.random.uniform(shape=(nch,)) - bn_beta.attach_grad(grad_req=grad_req) + bn_beta.attach_grad(grad_req=beta_grad_req) bn_running_mean = mx.nd.zeros(nch) bn_running_var = mx.nd.ones(nch) @@ -1879,12 +1886,15 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] data = mx.nd.random.uniform(shape=shape) - data.attach_grad(grad_req=grad_req) + data.attach_grad(grad_req=data_grad_req) adX, adW, adb = 0, 0, 0 + is_train = data_grad_req != 'null' or \ + (not fix_gamma and gamma_grad_req != 'null') or \ + beta_grad_req != 'null' for _ in range(num_iters): - if grad_req != 'add': + if data_grad_req != 'add': data = mx.nd.random.uniform(shape=shape) - data.attach_grad(grad_req=grad_req) + data.attach_grad(grad_req=data_grad_req) ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): output = op(data, bn_gamma, bn_beta, @@ -1893,7 +1903,8 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, fix_gamma=fix_gamma, **kwargs) if output_mean_var: output, output_mean, output_std = output - output.backward(ograd) + if is_train: + output.backward(ograd) mx.nd.waitall() data_mean = data.mean( @@ -1930,17 +1941,11 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) - if grad_req == 'add': - adX += dX - adW += dW - adb += db - else: - adX, adW, adb = dX, dW, db + adX = dX if data_grad_req != 'add' else adX + dX + adW = dW if gamma_grad_req != 'add' else adW + dW + adb = db if beta_grad_req != 'add' else adb + db - if grad_req == 'add': - atol, rtol = 5e-2, 5e-2 - else: - atol, rtol = 1e-2, 1e-2 + atol, rtol = 5e-2, 5e-2 if output_mean_var: assert_almost_equal(output_mean.asnumpy(), @@ -1957,33 +1962,36 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, atol=atol, rtol=rtol) assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_mean.asnumpy( - ), running_mean.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_var.asnumpy( - ), running_var.asnumpy(), atol=atol, rtol=rtol) - - assert_almost_equal(data.grad.asnumpy(), - adX.asnumpy(), atol=atol, rtol=rtol) + if is_train: + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + if data_grad_req != 'null': + assert_almost_equal(data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) if not fix_gamma: - assert_almost_equal( - bn_gamma.grad.asnumpy(), adW.asnumpy(), - atol=atol, rtol=rtol) + if gamma_grad_req != 
'null': + assert_almost_equal( + bn_gamma.grad.asnumpy(), adW.asnumpy(), + atol=atol, rtol=rtol) else: assert((bn_gamma.asnumpy() == 1).all()) - assert_almost_equal( - bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - - for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: - for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), - (24, 8, 4, 4), (24, 5, 6, 4, 4)]: - for axis in range(len(shape)): - for fix_gamma in [False, True]: - for grad_req in ['write', 'add']: - for cudnn_off in [False, True]: - for output_mean_var in [False, True]: - _test_batchnorm_impl(op, shape, axis, fix_gamma, - grad_req, - cudnn_off, output_mean_var) + if beta_grad_req != 'null': + assert_almost_equal( + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) + + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req) + @with_seed() def test_groupnorm(): From b78b4cba232cb5fa662fa97ac6d274c58a5d46af Mon Sep 17 00:00:00 2001 From: wkcn Date: Tue, 9 Jun 2020 07:20:29 +0800 Subject: [PATCH 14/20] change condition --- src/operator/nn/batch_norm.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 588c4ed9625e..be444602da03 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -420,9 +420,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam ¶m) { - mxnet::TShape shape = input.shape(); - return SupportMKLDNN(input) - && !mxnet::op::batchnorm::disable_mkl; + if (mxnet::op::batchnorm::disable_mkl) return false; + const mxnet::TShape shape = input.shape(); + const int ndim = shape.ndim(); + if (ndim == 0 || shape.Size() == 0) return false; + const int dtype = input.dtype(); + return (dtype == mshadow::kFloat32 || + dtype == mshadow::mshadow::kBfloat16) && + SupportStorageMKLDNN(input.storage_type()); } void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs, From e4118901fa70f1926bfca3af85041c7ed05a58c8 Mon Sep 17 00:00:00 2001 From: wkcn Date: Tue, 9 Jun 2020 08:09:38 +0800 Subject: [PATCH 15/20] fix --- src/operator/nn/batch_norm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index be444602da03..1c55f3602164 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -426,7 +426,7 @@ static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &p if (ndim == 0 || shape.Size() == 0) return false; const int dtype = input.dtype(); return (dtype == mshadow::kFloat32 || - dtype == mshadow::mshadow::kBfloat16) && + dtype == mshadow::kBfloat16) && SupportStorageMKLDNN(input.storage_type()); } From aad431b0d670feeab1ab1ae52f7ebf248df54409 Mon Sep 17 00:00:00 2001 From: wkcn Date: Tue, 9 Jun 2020 17:30:50 +0800 Subject: [PATCH 16/20] fix --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index d39011ce7f3c..57626019a3cd 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ 
b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -276,17 +276,12 @@ class CuDNNBatchNormOp { shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); } } else { - if (in_data.ndim() == 4 && param_.axis == 1) { - for (int i = 0; i < 4; ++i) - shape_[i] = in_data.shape_[i]; - } else { - // reshape to (N, C, 1, D), C is the `param_.axis` dimension - shape_[0] = static_cast(in_data.shape_.ProdShape(0, param_.axis)); - shape_[1] = in_data.shape_[param_.axis]; - shape_[2] = 1; - shape_[3] = static_cast(in_data.shape_.ProdShape(param_.axis + 1, - static_cast(in_data.ndim()))); - } + // reshape to (N, C, 1, D), C is the `param_.axis` dimension + shape_[0] = static_cast(in_data.shape_.ProdShape(0, param_.axis)); + shape_[1] = in_data.shape_[param_.axis]; + shape_[2] = 1; + shape_[3] = static_cast(in_data.shape_.ProdShape(param_.axis + 1, + static_cast(in_data.ndim()))); } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, From 939750f882144bdd268ca1995d79784959728d83 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 12 Jun 2020 01:54:49 +0800 Subject: [PATCH 17/20] add mkldnn_off for bn --- src/operator/nn/batch_norm-inl.h | 4 ++++ src/operator/nn/batch_norm.cc | 2 +- tests/python/unittest/save_000800.json | 5 +++-- tests/python/unittest/test_symbol.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 485b3b33f6a8..a80b7330f521 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -74,6 +74,7 @@ struct BatchNormParam : public dmlc::Parameter { bool output_mean_var; int axis; bool cudnn_off; + bool mkldnn_off; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset @@ -96,6 +97,8 @@ struct BatchNormParam : public dmlc::Parameter { .describe("Specify which shape axis the channel is specified"); DMLC_DECLARE_FIELD(cudnn_off).set_default(false) .describe("Do not select CUDNN operator, if available"); + DMLC_DECLARE_FIELD(mkldnn_off).set_default(false) + .describe("Do not select MKLDNN operator, if available"); DMLC_DECLARE_FIELD(min_calib_range) .set_default(dmlc::optional()) .describe("The minimum scalar value in the form of float32 obtained " @@ -116,6 +119,7 @@ struct BatchNormParam : public dmlc::Parameter { this->use_global_stats == other.use_global_stats && this->output_mean_var == other.output_mean_var && this->axis == other.axis && this->cudnn_off == other.cudnn_off && + this->mkldnn_off == other.mkldnn_off && this->min_calib_range.has_value() == other.min_calib_range.has_value() && this->max_calib_range.has_value() == other.max_calib_range.has_value(); if (this->min_calib_range.has_value() && other.min_calib_range.has_value() && diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 1c55f3602164..d2d11b9c7248 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -420,7 +420,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam ¶m) { - if (mxnet::op::batchnorm::disable_mkl) return false; + if (mxnet::op::batchnorm::disable_mkl || param.mkldnn_off) return false; const mxnet::TShape shape = input.shape(); const int ndim = shape.ndim(); if (ndim == 0 || shape.Size() == 0) return false; diff --git a/tests/python/unittest/save_000800.json b/tests/python/unittest/save_000800.json index 
7b385e2983d8..e0be90f65dca 100644 --- a/tests/python/unittest/save_000800.json +++ b/tests/python/unittest/save_000800.json @@ -151,7 +151,8 @@ "eps": "0.001", "fix_gamma": "True", "momentum": "0.9", - "use_global_stats": "False" + "use_global_stats": "False", + "mkldnn_off": "True" }, "name": "batchnorm0", "inputs": [[11, 0], [12, 0], [13, 0]], @@ -185,4 +186,4 @@ ], "arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12, 13, 15], "heads": [[16, 0]] -} \ No newline at end of file +} diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index b5205787d1ba..63e7eca1484a 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -284,7 +284,7 @@ def test_load_000800(): fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, lr_mult=0.01) act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) - fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0') + fc3 = mx.symbol.BatchNorm(fc3, mkldnn_off=True, name='batchnorm0') sym1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax') curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) From 25e4666678db8c5f8032d26cbdc108ee6a7d219d Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 1 Jul 2020 08:18:52 +0800 Subject: [PATCH 18/20] remove mkldnn_off --- src/operator/nn/batch_norm-inl.h | 4 ---- src/operator/nn/batch_norm.cc | 2 +- tests/python/unittest/save_000800.json | 3 +-- tests/python/unittest/test_symbol.py | 2 +- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index a80b7330f521..485b3b33f6a8 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -74,7 +74,6 @@ struct BatchNormParam : public dmlc::Parameter { bool output_mean_var; int axis; bool cudnn_off; - bool mkldnn_off; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset @@ -97,8 +96,6 @@ struct BatchNormParam : public dmlc::Parameter { .describe("Specify which shape axis the channel is specified"); DMLC_DECLARE_FIELD(cudnn_off).set_default(false) .describe("Do not select CUDNN operator, if available"); - DMLC_DECLARE_FIELD(mkldnn_off).set_default(false) - .describe("Do not select MKLDNN operator, if available"); DMLC_DECLARE_FIELD(min_calib_range) .set_default(dmlc::optional()) .describe("The minimum scalar value in the form of float32 obtained " @@ -119,7 +116,6 @@ struct BatchNormParam : public dmlc::Parameter { this->use_global_stats == other.use_global_stats && this->output_mean_var == other.output_mean_var && this->axis == other.axis && this->cudnn_off == other.cudnn_off && - this->mkldnn_off == other.mkldnn_off && this->min_calib_range.has_value() == other.min_calib_range.has_value() && this->max_calib_range.has_value() == other.max_calib_range.has_value(); if (this->min_calib_range.has_value() && other.min_calib_range.has_value() && diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index d2d11b9c7248..1c55f3602164 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -420,7 +420,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam ¶m) { - if (mxnet::op::batchnorm::disable_mkl || param.mkldnn_off) return false; + if 
(mxnet::op::batchnorm::disable_mkl) return false; const mxnet::TShape shape = input.shape(); const int ndim = shape.ndim(); if (ndim == 0 || shape.Size() == 0) return false; diff --git a/tests/python/unittest/save_000800.json b/tests/python/unittest/save_000800.json index e0be90f65dca..56cabc538eea 100644 --- a/tests/python/unittest/save_000800.json +++ b/tests/python/unittest/save_000800.json @@ -151,8 +151,7 @@ "eps": "0.001", "fix_gamma": "True", "momentum": "0.9", - "use_global_stats": "False", - "mkldnn_off": "True" + "use_global_stats": "False" }, "name": "batchnorm0", "inputs": [[11, 0], [12, 0], [13, 0]], diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 63e7eca1484a..b5205787d1ba 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -284,7 +284,7 @@ def test_load_000800(): fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, lr_mult=0.01) act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) - fc3 = mx.symbol.BatchNorm(fc3, mkldnn_off=True, name='batchnorm0') + fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0') sym1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax') curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) From 97c6746b3bfacb95387aa3142d39eea162d957df Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 1 Jul 2020 08:44:04 +0800 Subject: [PATCH 19/20] recover save_000800.json --- tests/python/unittest/save_000800.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/save_000800.json b/tests/python/unittest/save_000800.json index 56cabc538eea..7b385e2983d8 100644 --- a/tests/python/unittest/save_000800.json +++ b/tests/python/unittest/save_000800.json @@ -185,4 +185,4 @@ ], "arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12, 13, 15], "heads": [[16, 0]] -} +} \ No newline at end of file From 38da95f9921e0be7eba8aadedbcb8c525d71a436 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 5 Jul 2020 18:45:57 +0800 Subject: [PATCH 20/20] cast --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 57626019a3cd..340c2f3494f2 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -273,7 +273,8 @@ class CuDNNBatchNormOp { shape_[0] = in_data.shape_[0]; shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; shape_[2] = 1; - shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); + shape_[3] = static_cast(in_data.shape_.ProdShape(2, + in_data.ndim())); } } else { // reshape to (N, C, 1, D), C is the `param_.axis` dimension @@ -281,7 +282,7 @@ class CuDNNBatchNormOp { shape_[1] = in_data.shape_[param_.axis]; shape_[2] = 1; shape_[3] = static_cast(in_data.shape_.ProdShape(param_.axis + 1, - static_cast(in_data.ndim()))); + in_data.ndim())); } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
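
For readers following the series: the net effect of the test_operator.py changes above is that BatchNorm now honours grad_req='add' (and 'null') for its data, gamma and beta inputs. The sketch below is not code from the patches — shapes, values and variable names are illustrative only — but shows the behaviour the tests exercise, using the same imperative NDArray API:

    import mxnet as mx

    shape = (4, 3, 2, 2)
    data = mx.nd.random.uniform(shape=shape)
    gamma = mx.nd.random.uniform(shape=(3,))
    beta = mx.nd.random.uniform(shape=(3,))
    running_mean = mx.nd.zeros((3,))
    running_var = mx.nd.ones((3,))

    # 'add' accumulates into .grad on every backward instead of overwriting it.
    data.attach_grad(grad_req='add')
    gamma.attach_grad(grad_req='add')
    beta.attach_grad(grad_req='add')

    for _ in range(2):
        with mx.autograd.record():
            out = mx.nd.BatchNorm(data, gamma, beta, running_mean, running_var,
                                  momentum=0.9, eps=1e-5, fix_gamma=False)
        out.backward(mx.nd.ones(shape))

    # After two identical passes the accumulated gradients in data.grad,
    # gamma.grad and beta.grad should be roughly twice the single-pass values.

With grad_req='write' the second backward would simply overwrite the first pass's gradients; with 'add' the .grad buffers hold the running sum across passes, which is the accumulation behaviour this series adds and tests.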
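
PATCH 16/20 and PATCH 20/20 both concern the 4-D view handed to cuDNN: regardless of the input rank, the tensor is described as (N, C, 1, D) with N = prod(shape[:axis]), C = shape[axis] and D = prod(shape[axis+1:]), mirroring the ProdShape calls in the diffs. A small stand-alone illustration of that arithmetic — the helper name cudnn_bn_view is ours, not MXNet code:

    import numpy as np

    def cudnn_bn_view(shape, axis):
        # Collapse everything before `axis` into N, keep the channel axis as C,
        # and collapse everything after `axis` into D: the view is (N, C, 1, D).
        n = int(np.prod(shape[:axis], dtype=np.int64))
        c = int(shape[axis])
        d = int(np.prod(shape[axis + 1:], dtype=np.int64))
        return (n, c, 1, d)

    # Shapes taken from the test list used in test_batchnorm above.
    assert cudnn_bn_view((24, 8, 4, 4), axis=1) == (24, 8, 1, 16)
    assert cudnn_bn_view((24, 5, 6, 4, 4), axis=2) == (120, 6, 1, 16)

The casts added in PATCH 20/20 do not change these values; they only make the integer conversion of the ProdShape results explicit before they are stored for the cudnnSetTensor4dDescriptor call.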