From fac3b42f28e6dc96c0046092c190f1261e0f1039 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 5 Jun 2020 19:17:54 +0800 Subject: [PATCH 01/23] fix batch norm when fix_gamma is True --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 2 +- tests/python/unittest/test_operator.py | 35 +++++++++++++------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 94a1572b8db3..f562f38d698c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -228,7 +228,7 @@ class CuDNNBatchNormOp { &a, &b, &a, - req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add, + req[cudnnbatchnorm::kGamma] == kAddTo ? &b_add : &b, io_desc_, x.dptr_, io_desc_, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index bdddf626b797..0ab2190a36f0 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1842,8 +1842,10 @@ def test_batchnorm(): momentum = 0.9 epsilon = 1e-5 - def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): - print(str((op, shape, axis, cudnn_off))) + def _test_batchnorm_impl(op, shape, axis, fix_gamma, + cudnn_off, output_mean_var): + print(str((op, shape, axis, fix_gamma, + cudnn_off, output_mean_var))) kwargs = dict(output_mean_var=output_mean_var) if op == mx.nd.contrib.SyncBatchNorm: @@ -1857,8 +1859,11 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) nch = shape[axis] - bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad() + if not fix_gamma: + bn_gamma = mx.nd.random.uniform(shape=(nch,)) + bn_gamma.attach_grad() + else: + bn_gamma = mx.nd.ones(shape=(nch,)) bn_beta = mx.nd.random.uniform(shape=(nch,)) bn_beta.attach_grad() @@ -1879,7 +1884,7 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, momentum=momentum, eps=epsilon, - fix_gamma=False, **kwargs) + fix_gamma=fix_gamma, **kwargs) if output_mean_var: output, output_mean, output_std = output output.backward(ograd) @@ -1945,18 +1950,24 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): assert_almost_equal(data.grad.asnumpy(), dX.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal( - bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol) + if not fix_gamma: + assert_almost_equal( + bn_gamma.grad.asnumpy(), dW.asnumpy(), + atol=atol, rtol=rtol) + else: + assert((bn_gamma.asnumpy() == 1).all()) assert_almost_equal( bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: - for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]: + for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), + (24, 8, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): - for cudnn_off in [False, True]: - for output_mean_var in [False, True]: - _test_batchnorm_impl(op, shape, axis, - cudnn_off, output_mean_var) + for fix_gamma in [False, True]: + for cudnn_off in [False, True]: + for output_mean_var in [False, True]: + _test_batchnorm_impl(op, shape, axis, fix_gamma, + cudnn_off, output_mean_var) @with_seed() From 34c6a3b4f04cc5b0c317a0695653adf03218b3a3 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 10:15:37 +0800 Subject: [PATCH 02/23] support gradient accumulation for batch norm --- src/operator/nn/batch_norm-inl.h | 11 +++ src/operator/nn/batch_norm.cc | 84 ++++++++++++++----- .../nn/mkldnn/mkldnn_batch_norm-inl.h | 30 +++++-- 3 files changed, 98 insertions(+), 27 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 17a16db5adcd..77ef666f7e04 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -259,6 +259,17 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, const std::vector &outputs) { CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 3U); + + // check req + bool req_write_existed = false, req_addto_existed = false; + for (const OpReqType &r : req) { + if (IsBNWriting(r)) req_write_existed = true; + else if (r == kAddTo) req_addto_existed = true; + } + CHECK_EQ(req_write_existed && req_addto_existed, true) \ + << "BatchNorm does not support `grad_req` of two inputs \ +are `write` and `add` simultaneously"; + std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index b865269fc6f5..b3097c2870f2 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3 &in_data, } } +template +static inline void ForEachFast(const BNTensor3 &in_data, + const BNTensor3 &in_data2, + const BNTensor3 &out_data, + const size_t channel, + OnData onData) { + const size_t num = in_data.OuterSize(); + const size_t matrixSize = in_data.InnerSize(); + const size_t skipLength = in_data.SkipLengthToNextSameChannelData(); + const size_t startOffset = in_data.StartOffset(channel); + + DType1 *data = in_data.dptr_ + startOffset; + DType2 *data2 = in_data2.dptr_ + startOffset; + DType3 *odata = out_data.dptr_ + startOffset; + + for (size_t outer = 0; outer < num; ++outer) { + for (size_t i = 0; i < matrixSize; ++i) { + onData(data++, data2++, odata++); + } + data += skipLength; + data2 += skipLength; + odata += skipLength; + } +} + } // namespace batchnorm /*! \brief Forward CPU */ @@ -263,7 +288,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, dotp += (*thisInputData - mean) * (*gradOut_data); }); - if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) { // if there's a grad input + if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) { // if there's a grad input if (is_train_and_not_global_stats) { // when in training mode // Q(X) = X - E[x] ; i.e. input centered to zero mean @@ -272,44 +297,59 @@ void BatchNormBackwardImpl(mshadow::Stream *, // projection of gradOutput on to output scaled by std const AccReal k = dotp * invstd * invstd / itemCount; - ForEachFast(inputData, gradIn, static_cast(channel), - [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { - *gradIn_data = (*inputDataPtr - mean) * k; - }); - const AccReal iw = invstd * w; const AccReal gradMean = sumGradOut / itemCount; - ForEachFast(gradOut, gradIn, static_cast(channel), - [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) { - *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw; - }); + if (req[batchnorm::kData != kAddTo) { + ForEachFast(inputData, gradIn, static_cast(channel), + [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { + *gradIn_data = (*inputDataPtr - mean) * k; + }); + + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw; + }); + } else { + ForEachFast(inputData, gradOut, gradIn, static_cast(channel), + [&mean, &k, iw, gradMean](const DType *inputDataPtr, + const DType *gradOut_data, + DType *gradIn_data) { + DType normal_val = (*inputDataPtr - mean) * k; + *gradIn_data += (*gradOut_data - gradMean - + normal_val) * iw; + }); + } } else { // when in evaluation mode // Q(X) = X - running_mean ; i.e. input centered to zero mean // Y = Q(X) / running_std ; i.e. BN output before weight and bias // dL/dX = w / running_std const AccReal iw = invstd * w; - ForEachFast(gradOut, gradIn, static_cast(channel), - [iw](const DType *gradOut_data, DType *gradIn_data) { - *gradIn_data = *gradOut_data * iw; - }); + if (req[batchnorm::kData != kAddTo) { + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data = *gradOut_data * iw; + }); + } else { + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data += *gradOut_data * iw; + }); + } } } // May want to make this a param eventually const AccReal scale = 1.0f; - if (IsBNWriting(req[batchnorm::kGamma])) { - if (!param_.fix_gamma) { - gradWeightData[channel] = scale * dotp * invstd; - } else { + if (!param_.fix_gamma) { + KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); + } else { + if (req[batchnorm::kGamma] != kNullOp) gradWeightData[channel] = AccReal(0); - } } - if (IsBNWriting(req[batchnorm::kBeta])) { - gradBiasData[channel] = scale * sumGradOut; - } + KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); } } diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index d7401340f20f..bdbc47ffa03b 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -300,6 +300,7 @@ template void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs, bool fuse_relu) { + CHECK_NE(req[batchnorm::kData], kAddTo) << "MKLDNN BatchNorm does not support `data.grad_req = add`"; if (fuse_relu) { CHECK_EQ(inputs.size(), 9U); } else { @@ -412,17 +413,36 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, // copy data from gradw_mem to in_grad[1] and in_grad[2] DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); - DType *w_grad_1 = in_grad[1].data().dptr(); - DType *w_grad_2 = in_grad[2].data().dptr(); + DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr(); + DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr(); + // the gradient of gamma if (!param.fix_gamma) { - memcpy(w_grad_1, gw_buf, copy_size); - memcpy(w_grad_2, &gw_buf[channels_], copy_size); + if (req[batchnorm::kGamma] != kNullOp) { + if (req[batchnorm::kGamma] != kAddTo) { + memcpy(w_grad_1, gw_buf, copy_size); + } else { + for (int i = 0; i < channels_; i++) { + w_grad_1[i] += gw_buf[i]; + } + } + } } else { for (int i = 0; i < channels_; i++) { (in_grad[1].data().dptr())[i] = 0.0f; } - memcpy(w_grad_2, &gw_buf[channels_], copy_size); + } + + // the gradient of beta + if (req[batchnorm::kBeta] != kNullOp) { + if (req[batchnorm::kBeta] != kAddTo) { + memcpy(w_grad_2, &gw_buf[channels_], copy_size); + } else { + DType *grad_beta = &gw_buf[channels_]; + for (int i = 0; i < channels_; i++) { + w_grad_2[i] += grad_beta[i]; + } + } } } else { LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ..."; From 11c8e26f2057a758275d11411d21208e03f566b2 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 10:59:43 +0800 Subject: [PATCH 03/23] mkldnn batchnorm support grad add --- src/operator/nn/batch_norm.cc | 4 ++-- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index b3097c2870f2..f46ab6b9e945 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -299,7 +299,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, const AccReal k = dotp * invstd * invstd / itemCount; const AccReal iw = invstd * w; const AccReal gradMean = sumGradOut / itemCount; - if (req[batchnorm::kData != kAddTo) { + if (req[batchnorm::kData] != kAddTo) { ForEachFast(inputData, gradIn, static_cast(channel), [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { *gradIn_data = (*inputDataPtr - mean) * k; @@ -325,7 +325,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, // Y = Q(X) / running_std ; i.e. BN output before weight and bias // dL/dX = w / running_std const AccReal iw = invstd * w; - if (req[batchnorm::kData != kAddTo) { + if (req[batchnorm::kData] != kAddTo) { ForEachFast(gradOut, gradIn, static_cast(channel), [iw](const DType *gradOut_data, DType *gradIn_data) { *gradIn_data = *gradOut_data * iw; diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index bdbc47ffa03b..fdc4f9c946c4 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -300,7 +300,6 @@ template void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs, bool fuse_relu) { - CHECK_NE(req[batchnorm::kData], kAddTo) << "MKLDNN BatchNorm does not support `data.grad_req = add`"; if (fuse_relu) { CHECK_EQ(inputs.size(), 9U); } else { @@ -348,7 +347,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, else if (diff.IsDefaultData()) diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc()); auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); - auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_desc()); + auto gradi_mem = CreateMKLDNNMem(const_cast(gradIn), + bwd.GetDataPd().diff_src_desc(), req[batchnorm::kData]); if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; @@ -369,7 +369,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } mkldnn_args_map_t net_args; net_args[MKLDNN_ARG_SRC] = *data_mem; - net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem; + net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second; net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight(); net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw(); net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem; @@ -408,6 +408,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); MKLDNNStream::Get()->Submit(); } From 830ba0dbced293edc8f93c4ec6b8e9b77f490d29 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 10:59:54 +0800 Subject: [PATCH 04/23] unittest for bn --- tests/python/unittest/test_operator.py | 42 +++++++++++++++++++------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0ab2190a36f0..1e500a1a17ad 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1843,8 +1843,10 @@ def test_batchnorm(): epsilon = 1e-5 def _test_batchnorm_impl(op, shape, axis, fix_gamma, + grad_req, cudnn_off, output_mean_var): print(str((op, shape, axis, fix_gamma, + grad_req, cudnn_off, output_mean_var))) kwargs = dict(output_mean_var=output_mean_var) @@ -1861,12 +1863,12 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, if not fix_gamma: bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad() + bn_gamma.attach_grad(grad_req=grad_req) else: bn_gamma = mx.nd.ones(shape=(nch,)) bn_beta = mx.nd.random.uniform(shape=(nch,)) - bn_beta.attach_grad() + bn_beta.attach_grad(grad_req=grad_req) bn_running_mean = mx.nd.zeros(nch) bn_running_var = mx.nd.ones(nch) @@ -1876,12 +1878,19 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, num_iters = 10 expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] + other_data = mx.nd.zeros(shape=shape) + other_data.attach_grad(grad_req=grad_req) + adX, adW, adb = 0, 0, 0 for _ in range(num_iters): data = mx.nd.random.uniform(shape=shape) - data.attach_grad() + data.attach_grad(grad_req=grad_req) + if grad_req == 'add': + mixed_data = data + other_data + else: + mixed_data = data ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): - output = op(data, bn_gamma, bn_beta, + output = op(mixed_data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, momentum=momentum, eps=epsilon, fix_gamma=fix_gamma, **kwargs) @@ -1924,6 +1933,12 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) + if grad_req == 'add': + adX += dX + adW += dW + adb += db + else: + adX, adW, adb = dX, dW, db atol = 1e-2 rtol = 1e-2 @@ -1950,25 +1965,30 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, assert_almost_equal(data.grad.asnumpy(), dX.asnumpy(), atol=atol, rtol=rtol) + + if grad_req == 'add': + assert_almost_equal(other_data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) if not fix_gamma: assert_almost_equal( - bn_gamma.grad.asnumpy(), dW.asnumpy(), + bn_gamma.grad.asnumpy(), adW.asnumpy(), atol=atol, rtol=rtol) else: assert((bn_gamma.asnumpy() == 1).all()) assert_almost_equal( - bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): for fix_gamma in [False, True]: - for cudnn_off in [False, True]: - for output_mean_var in [False, True]: - _test_batchnorm_impl(op, shape, axis, fix_gamma, - cudnn_off, output_mean_var) - + for grad_req in ['write', 'add']: + for cudnn_off in [False, True]: + for output_mean_var in [False, True]: + _test_batchnorm_impl(op, shape, axis, fix_gamma, + grad_req, + cudnn_off, output_mean_var) @with_seed() def test_groupnorm(): From a9c91d20dc02aae8e6b43738d3b542a7e98cc664 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 14:27:34 +0800 Subject: [PATCH 05/23] fix bn arg --- src/operator/nn/batch_norm-inl.h | 2 +- tests/python/unittest/test_operator.py | 27 +++++++++++--------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 77ef666f7e04..d6eafaf90abb 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -266,7 +266,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, if (IsBNWriting(r)) req_write_existed = true; else if (r == kAddTo) req_addto_existed = true; } - CHECK_EQ(req_write_existed && req_addto_existed, true) \ + CHECK_EQ(req_write_existed && req_addto_existed, false) \ << "BatchNorm does not support `grad_req` of two inputs \ are `write` and `add` simultaneously"; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1e500a1a17ad..bd82347f40ea 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1878,19 +1878,16 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, num_iters = 10 expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] - other_data = mx.nd.zeros(shape=shape) - other_data.attach_grad(grad_req=grad_req) + data = mx.nd.random.uniform(shape=shape) + data.attach_grad(grad_req=grad_req) adX, adW, adb = 0, 0, 0 for _ in range(num_iters): - data = mx.nd.random.uniform(shape=shape) - data.attach_grad(grad_req=grad_req) - if grad_req == 'add': - mixed_data = data + other_data - else: - mixed_data = data + if grad_req != 'add': + data = mx.nd.random.uniform(shape=shape) + data.attach_grad(grad_req=grad_req) ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): - output = op(mixed_data, bn_gamma, bn_beta, + output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, momentum=momentum, eps=epsilon, fix_gamma=fix_gamma, **kwargs) @@ -1940,8 +1937,10 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, else: adX, adW, adb = dX, dW, db - atol = 1e-2 - rtol = 1e-2 + if grad_req == 'add': + atol, rtol = 5e-1, 5e-1 + else: + atol, rtol = 1e-2, 1e-2 if output_mean_var: assert_almost_equal(output_mean.asnumpy(), @@ -1964,11 +1963,7 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, ), running_var.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(data.grad.asnumpy(), - dX.asnumpy(), atol=atol, rtol=rtol) - - if grad_req == 'add': - assert_almost_equal(other_data.grad.asnumpy(), - adX.asnumpy(), atol=atol, rtol=rtol) + adX.asnumpy(), atol=atol, rtol=rtol) if not fix_gamma: assert_almost_equal( bn_gamma.grad.asnumpy(), adW.asnumpy(), From 85946423f0184653577f5924f507e10446c6890c Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 14:36:23 +0800 Subject: [PATCH 06/23] fix lint --- src/operator/nn/batch_norm-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index d6eafaf90abb..9656784eeb4c 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -267,8 +267,8 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, else if (r == kAddTo) req_addto_existed = true; } CHECK_EQ(req_write_existed && req_addto_existed, false) \ - << "BatchNorm does not support `grad_req` of two inputs \ -are `write` and `add` simultaneously"; + << "BatchNorm does not support `grad_req` of two inputs" + "are `write` and `add` simultaneously"; std::vector out_grad(1); std::vector out_data(3); From f9995b096e68632963054fe962565bf5eeb74f7b Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 17:06:23 +0800 Subject: [PATCH 07/23] fix mkldnn --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index fdc4f9c946c4..d869542a0346 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -348,7 +348,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc()); auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); auto gradi_mem = CreateMKLDNNMem(const_cast(gradIn), - bwd.GetDataPd().diff_src_desc(), req[batchnorm::kData]); + bwd.pd.diff_src_desc(), req[batchnorm::kData]); if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; @@ -403,12 +403,12 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = var_mem; MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); MKLDNNStream::Get()->Submit(); } else { net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - CommitOutput(gradIn, gradi_mem); MKLDNNStream::Get()->Submit(); } From d38044345f796a620c0e9f4ff30a318a939d9f4f Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 17:25:35 +0800 Subject: [PATCH 08/23] fix mkldnn bn --- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 8 +++----- tests/python/unittest/test_operator.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index d869542a0346..da4fd97e82da 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -402,15 +402,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = var_mem; - MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - CommitOutput(gradIn, gradi_mem); - MKLDNNStream::Get()->Submit(); } else { net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); - MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - MKLDNNStream::Get()->Submit(); } + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); + MKLDNNStream::Get()->Submit(); // copy data from gradw_mem to in_grad[1] and in_grad[2] DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index bd82347f40ea..8cf657a481e7 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1938,7 +1938,7 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, adX, adW, adb = dX, dW, db if grad_req == 'add': - atol, rtol = 5e-1, 5e-1 + atol, rtol = 5e-2, 5e-2 else: atol, rtol = 1e-2, 1e-2 From eaeae2101c0d339fdcec7efadf4664e4b1c2bc86 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 17:32:37 +0800 Subject: [PATCH 09/23] fix grad when fixing gamma --- src/operator/nn/batch_norm.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index f46ab6b9e945..2b2b11e553fc 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -345,8 +345,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, if (!param_.fix_gamma) { KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); } else { - if (req[batchnorm::kGamma] != kNullOp) - gradWeightData[channel] = AccReal(0); + gradWeightData[channel] = AccReal(0); } KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); From 25924c956475104111d54b7dcb3df1ae396600c0 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 20:58:31 +0800 Subject: [PATCH 10/23] fix naive gpu bn --- src/operator/nn/batch_norm.cu | 66 ++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 84d1cecd93d9..2f990eb7b617 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -34,6 +34,9 @@ #define FIX_GAMMA_FLAG 8 #define IS_TRAINING_FLAG 16 #define USE_GLOBAL_STATS_FLAG 32 +#define ADDTO_DATA_FLAG (1 << 6) +#define ADDTO_GAMMA_FLAG (1 << 7) +#define ADDTO_BETA_FLAG (1 << 8) #if MXNET_USE_CUDNN == 1 #include "./cudnn/cudnn_batch_norm-inl.h" @@ -361,33 +364,58 @@ static __global__ void BatchNormalizationBackwardKernel( * momentum + localVariance * (AccReal(1) - momentum); } - if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { - for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { - const DType gradOut = gradOutput.get_ref(batch, plane, x); - if (is_train_and_not_global_stats) { - const DType inp = input.get_ref(batch, plane, x); - const AccReal proj = (inp - mean) * projScale; - gradInput.get_ref(batch, plane, x) = - ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); + if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) { + const bool grad_write = flags & WRITE_DATA_FLAG; + if (grad_write) { + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + if (is_train_and_not_global_stats) { + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) = + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); + } else { + gradInput.get_ref(batch, plane, x) = ScalarConvert::to( + gradOut * gradScale); + } + } + } + } else { + // grad addto + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + if (is_train_and_not_global_stats) { + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) += + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); + } else { + gradInput.get_ref(batch, plane, x) += ScalarConvert::to( + gradOut * gradScale); + } } } } } - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { if ((flags & FIX_GAMMA_FLAG) == 0) { - tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + if (flags & WRITE_GAMMA_FLAG) + tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + else + tensors.gradWeight[plane] += ScalarConvert::to(dotP * invstd); } else { tensors.gradWeight[plane] = DType(0); } } - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { - tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { + if (flags & WRITE_BETA_FLAG) + tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + else + tensors.gradBias[plane] += ScalarConvert::to(gradOutputSum); } } @@ -579,12 +607,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx, flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0; if (IsBNWriting(req[batchnorm::kData])) { flags |= WRITE_DATA_FLAG; + } else if (req[batchnorm::kData] == kAddTo) { + flags |= ADDTO_DATA_FLAG; } if (IsBNWriting(req[batchnorm::kGamma])) { flags |= WRITE_GAMMA_FLAG; + } else if (req[batchnorm::kGamma] == kAddTo) { + flags |= ADDTO_GAMMA_FLAG; } if (IsBNWriting(req[batchnorm::kBeta])) { flags |= WRITE_BETA_FLAG; + } else if (req[batchnorm::kBeta] == kAddTo) { + flags |= ADDTO_BETA_FLAG; } return flags; } From e3bb53c44ef91d6b4e35daba75f74f54dd4dd5d4 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 6 Jun 2020 21:30:22 +0800 Subject: [PATCH 11/23] fix lint --- src/operator/nn/batch_norm.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 2f990eb7b617..0875f05e669d 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -400,7 +400,8 @@ static __global__ void BatchNormalizationBackwardKernel( } } - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && + (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { if ((flags & FIX_GAMMA_FLAG) == 0) { if (flags & WRITE_GAMMA_FLAG) tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); @@ -411,7 +412,8 @@ static __global__ void BatchNormalizationBackwardKernel( } } - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && + (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { if (flags & WRITE_BETA_FLAG) tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); else From ba02f4f1bac9768f767ecd9f92e93923f6ee8691 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 7 Jun 2020 00:38:47 +0800 Subject: [PATCH 12/23] fix cudnn bn --- src/operator/nn/batch_norm-inl.h | 10 ---------- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 15 +++++++++++++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 9656784eeb4c..485b3b33f6a8 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -260,16 +260,6 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 3U); - // check req - bool req_write_existed = false, req_addto_existed = false; - for (const OpReqType &r : req) { - if (IsBNWriting(r)) req_write_existed = true; - else if (r == kAddTo) req_addto_existed = true; - } - CHECK_EQ(req_write_existed && req_addto_existed, false) \ - << "BatchNorm does not support `grad_req` of two inputs" - "are `write` and `add` simultaneously"; - std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index f562f38d698c..1101858b2cc6 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -222,13 +222,24 @@ class CuDNNBatchNormOp { if (param_.fix_gamma) gamma = 1.f; + bool grad_add_gamma_beta = req[cudnnbatchnorm::kGamma] || + req[cudnnbatchnorm::kBeta]; + if (grad_add_gamma_beta) { + if (IsBNWriting(req[cudnnbatchnorm::kGamma])) { + dgamma = 0.f; + } + if (IsBNWriting(req[cudnnbatchnorm::kBeta])) { + dbeta = 0.f; + } + } + CUDNN_CALL(cudnnBatchNormalizationBackward( s->dnn_handle_, mode, &a, - &b, + req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b, &a, - req[cudnnbatchnorm::kGamma] == kAddTo ? &b_add : &b, + grad_add_gamma_beta ? &b_add : &b, // gamma and beta io_desc_, x.dptr_, io_desc_, From fcf088e531124c831367d105d7e0fe2963725879 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 7 Jun 2020 00:40:30 +0800 Subject: [PATCH 13/23] fix flag --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 1101858b2cc6..82006a5be0cd 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -222,8 +222,8 @@ class CuDNNBatchNormOp { if (param_.fix_gamma) gamma = 1.f; - bool grad_add_gamma_beta = req[cudnnbatchnorm::kGamma] || - req[cudnnbatchnorm::kBeta]; + bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) || + (req[cudnnbatchnorm::kBeta] == kAddTo); if (grad_add_gamma_beta) { if (IsBNWriting(req[cudnnbatchnorm::kGamma])) { dgamma = 0.f; From a04df8d4fa76464f095ee130f767600566124ef0 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 7 Jun 2020 00:54:08 +0800 Subject: [PATCH 14/23] fix lint --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 82006a5be0cd..13db44d518b3 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -239,7 +239,7 @@ class CuDNNBatchNormOp { &a, req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b, &a, - grad_add_gamma_beta ? &b_add : &b, // gamma and beta + grad_add_gamma_beta ? &b_add : &b, // gamma and beta io_desc_, x.dptr_, io_desc_, From ad979c729c6bbba052690d20b8761c2b9e344dd0 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 7 Jun 2020 05:55:44 +0800 Subject: [PATCH 15/23] fix testcase --- src/operator/nn/batch_norm.cc | 4 +++- tests/python/unittest/test_operator.py | 5 +---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 2b2b11e553fc..d4b03ae3fc17 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -345,7 +345,9 @@ void BatchNormBackwardImpl(mshadow::Stream *, if (!param_.fix_gamma) { KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); } else { - gradWeightData[channel] = AccReal(0); + if (IsBNWriting(req[batchnorm::kGamma])) { + gradWeightData[channel] = AccReal(0); + } } KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 8cf657a481e7..49c038a1c187 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1937,10 +1937,7 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, else: adX, adW, adb = dX, dW, db - if grad_req == 'add': - atol, rtol = 5e-2, 5e-2 - else: - atol, rtol = 1e-2, 1e-2 + atol, rtol = 1e-2, 5e-2 if output_mean_var: assert_almost_equal(output_mean.asnumpy(), From 91815315f7caea620031d12c8ab1de374ff0061a Mon Sep 17 00:00:00 2001 From: wkcn Date: Sun, 7 Jun 2020 09:57:26 +0800 Subject: [PATCH 16/23] fix --- tests/python/unittest/test_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 49c038a1c187..d65ae1d28fba 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1937,7 +1937,7 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, else: adX, adW, adb = dX, dW, db - atol, rtol = 1e-2, 5e-2 + atol, rtol = 5e-2, 5e-2 if output_mean_var: assert_almost_equal(output_mean.asnumpy(), From 1063d9a89ed5d70e25c154117ddec81c9ce4ade4 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 10:31:17 +0800 Subject: [PATCH 17/23] use @pytest.mark.parametrize --- tests/python/unittest/test_operator.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d65ae1d28fba..d11497aa7bcb 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1838,7 +1838,14 @@ def check_batchnorm_training(stype): @xfail_when_nonstandard_decimal_separator @with_seed() -def test_batchnorm(): +@pytest.mark.parametrize('op', [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]) +@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 4, 4, 5), + (24, 8, 4, 5), (24, 5, 6, 4, 5)]) +@pytest.mark.parametrize('fix_gamma', [False, True]) +@pytest.mark.parametrize('grad_req', ['write', 'add']) +@pytest.mark.parametrize('cudnn_off', [False, True]) +@pytest.mark.parametrize('output_mean_var', [False, True]) +def test_batchnorm(op, shape, fix_gamma, grad_req, cudnn_off, output_mean_var): momentum = 0.9 epsilon = 1e-5 @@ -1970,17 +1977,9 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, assert_almost_equal( bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: - for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), - (24, 8, 4, 4), (24, 5, 6, 4, 4)]: - for axis in range(len(shape)): - for fix_gamma in [False, True]: - for grad_req in ['write', 'add']: - for cudnn_off in [False, True]: - for output_mean_var in [False, True]: - _test_batchnorm_impl(op, shape, axis, fix_gamma, - grad_req, - cudnn_off, output_mean_var) + for axis in range(len(shape)): + _test_batchnorm_impl(op, shape, axis, fix_gamma, + grad_req, cudnn_off, output_mean_var) @with_seed() def test_groupnorm(): From 096880565a49f39f85c85cc771252588b4f552a8 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 11:02:13 +0800 Subject: [PATCH 18/23] combination --- tests/python/unittest/test_operator.py | 72 +++++++++++++------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d11497aa7bcb..2a43317dd260 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1842,20 +1842,18 @@ def check_batchnorm_training(stype): @pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 4, 4, 5), (24, 8, 4, 5), (24, 5, 6, 4, 5)]) @pytest.mark.parametrize('fix_gamma', [False, True]) -@pytest.mark.parametrize('grad_req', ['write', 'add']) +@pytest.mark.parametrize('data_grad_req', ['null', 'write', 'add']) +@pytest.mark.parametrize('gamma_grad_req', ['null', 'write', 'add']) +@pytest.mark.parametrize('beta_grad_req', ['null', 'write', 'add']) @pytest.mark.parametrize('cudnn_off', [False, True]) @pytest.mark.parametrize('output_mean_var', [False, True]) -def test_batchnorm(op, shape, fix_gamma, grad_req, cudnn_off, output_mean_var): +def test_batchnorm(op, shape, fix_gamma, + data_grad_req, gamma_grad_req, beta_grad_req, + cudnn_off, output_mean_var): momentum = 0.9 epsilon = 1e-5 - def _test_batchnorm_impl(op, shape, axis, fix_gamma, - grad_req, - cudnn_off, output_mean_var): - print(str((op, shape, axis, fix_gamma, - grad_req, - cudnn_off, output_mean_var))) - + def _test_batchnorm_impl(axis): kwargs = dict(output_mean_var=output_mean_var) if op == mx.nd.contrib.SyncBatchNorm: if axis != 1: @@ -1870,12 +1868,12 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, if not fix_gamma: bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad(grad_req=grad_req) + bn_gamma.attach_grad(grad_req=gamma_grad_req) else: bn_gamma = mx.nd.ones(shape=(nch,)) bn_beta = mx.nd.random.uniform(shape=(nch,)) - bn_beta.attach_grad(grad_req=grad_req) + bn_beta.attach_grad(grad_req=beta_grad_req) bn_running_mean = mx.nd.zeros(nch) bn_running_var = mx.nd.ones(nch) @@ -1886,12 +1884,15 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] data = mx.nd.random.uniform(shape=shape) - data.attach_grad(grad_req=grad_req) + data.attach_grad(grad_req=data_grad_req) adX, adW, adb = 0, 0, 0 + is_train = data_grad_req != 'null' or \ + (not fix_gamma and gamma_grad_req != 'null') or \ + beta_grad_req != 'null' for _ in range(num_iters): - if grad_req != 'add': + if data_grad_req != 'add': data = mx.nd.random.uniform(shape=shape) - data.attach_grad(grad_req=grad_req) + data.attach_grad(grad_req=data_grad_req) ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): output = op(data, bn_gamma, bn_beta, @@ -1900,7 +1901,8 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, fix_gamma=fix_gamma, **kwargs) if output_mean_var: output, output_mean, output_std = output - output.backward(ograd) + if is_train: + output.backward(ograd) mx.nd.waitall() data_mean = data.mean( @@ -1937,12 +1939,9 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) - if grad_req == 'add': - adX += dX - adW += dW - adb += db - else: - adX, adW, adb = dX, dW, db + adX = dX if data_grad_req != 'add' else adX + dX + adW = dW if gamma_grad_req != 'add' else adW + dW + adb = db if beta_grad_req != 'add' else adb + db atol, rtol = 5e-2, 5e-2 @@ -1961,25 +1960,28 @@ def _test_batchnorm_impl(op, shape, axis, fix_gamma, atol=atol, rtol=rtol) assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_mean.asnumpy( - ), running_mean.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_var.asnumpy( - ), running_var.asnumpy(), atol=atol, rtol=rtol) - - assert_almost_equal(data.grad.asnumpy(), - adX.asnumpy(), atol=atol, rtol=rtol) + if is_train: + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + if data_grad_req != 'null': + assert_almost_equal(data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) if not fix_gamma: - assert_almost_equal( - bn_gamma.grad.asnumpy(), adW.asnumpy(), - atol=atol, rtol=rtol) + if gamma_grad_req != 'null': + assert_almost_equal( + bn_gamma.grad.asnumpy(), adW.asnumpy(), + atol=atol, rtol=rtol) else: assert((bn_gamma.asnumpy() == 1).all()) - assert_almost_equal( - bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) + if beta_grad_req != 'null': + assert_almost_equal( + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) for axis in range(len(shape)): - _test_batchnorm_impl(op, shape, axis, fix_gamma, - grad_req, cudnn_off, output_mean_var) + _test_batchnorm_impl(axis) @with_seed() def test_groupnorm(): From 60a2076226243786e5921e68b04a347a4c50fa1d Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 11:15:52 +0800 Subject: [PATCH 19/23] remove redundant test in batchnorm --- tests/python/unittest/test_operator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 2a43317dd260..31f5ba065a08 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1850,6 +1850,9 @@ def check_batchnorm_training(stype): def test_batchnorm(op, shape, fix_gamma, data_grad_req, gamma_grad_req, beta_grad_req, cudnn_off, output_mean_var): + if fix_gamma and gamma_grad_req != 'null': + # skip redundant test when fixing gamma + return momentum = 0.9 epsilon = 1e-5 From f768bf028305fedb8640b12b7e42b122d565272e Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 11:44:55 +0800 Subject: [PATCH 20/23] npx.batch_norm test --- tests/python/unittest/test_numpy_op.py | 162 +++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 45b6a9c7c217..60851d86c316 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1596,6 +1596,168 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, transpose_b=transpose_b)) +@with_seed() +@use_np +@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 4, 4, 5), + (24, 8, 4, 5), (24, 5, 6, 4, 5)]) +@pytest.mark.parametrize('fix_gamma', [False, True]) +@pytest.mark.parametrize('data_grad_req', ['null', 'write', 'add']) +@pytest.mark.parametrize('gamma_grad_req', ['null', 'write', 'add']) +@pytest.mark.parametrize('beta_grad_req', ['null', 'write', 'add']) +@pytest.mark.parametrize('cudnn_off', [False, True]) +@pytest.mark.parametrize('output_mean_var', [False, True]) +def test_npx_batch_norm(shape, fix_gamma, + data_grad_req, gamma_grad_req, beta_grad_req, + cudnn_off, output_mean_var): + shape = (24, 2) + fix_gamma = False + data_grad_req = 'write' + gamma_grad_req = 'write' + beta_grad_req = 'write' + cudnn_off = False + output_mean_var = False + + momentum = 0.9 + epsilon = 1e-5 + if fix_gamma and gamma_grad_req != 'null': + # skip redundant test when fixing gamma + return + class TestBatchNorm(HybridBlock): + def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs): + super().__init__() + self.eps = eps + self.fix_gamma = fix_gamma + self.momentum = momentum + self.kwargs = kwargs + def hybrid_forward(self, F, data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var): + op = F.npx.batch_norm + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var, + momentum=self.momentum, eps=self.eps, + fix_gamma=self.fix_gamma, **self.kwargs) + return output + + def _test_batchnorm_impl(axis): + kwargs = dict(output_mean_var=output_mean_var) + kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) + op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs) + nch = shape[axis] + + if not fix_gamma: + bn_gamma = np.random.uniform(size=(nch,)) + bn_gamma.attach_grad(grad_req=gamma_grad_req) + else: + bn_gamma = np.ones((nch,)) + + bn_beta = np.random.uniform(size=(nch,)) + bn_beta.attach_grad(grad_req=beta_grad_req) + + bn_running_mean = np.zeros(nch) + bn_running_var = np.ones(nch) + + running_mean = np.zeros(nch) + running_var = np.ones(nch) + num_iters = 10 + expand_shape = [1] * len(shape) + expand_shape[axis] = shape[axis] + expand_shape = tuple(expand_shape) + data = np.random.uniform(size=shape) + data.attach_grad(grad_req=data_grad_req) + adX, adW, adb = 0, 0, 0 + is_train = data_grad_req != 'null' or \ + (not fix_gamma and gamma_grad_req != 'null') or \ + beta_grad_req != 'null' + for _ in range(num_iters): + if data_grad_req != 'add': + data = np.random.uniform(size=shape) + data.attach_grad(grad_req=data_grad_req) + ograd = np.random.uniform(size=shape) + with mx.autograd.record(): + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var) + if output_mean_var: + output, output_mean, output_std = output + if is_train: + output.backward(ograd) + mx.nd.waitall() + + assert 0 <= axis < data.ndim + reduce_axis = tuple(i for i in range(data.ndim) if i != axis) + assert len(reduce_axis) == data.ndim - 1 + data_mean = data.mean( + axis=reduce_axis, keepdims=True) + data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis, + keepdims=True) + + target_output = (data - data_mean) / \ + np.sqrt(data_var + epsilon) * \ + bn_gamma.reshape(expand_shape) + \ + bn_beta.reshape(expand_shape) + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + W = bn_gamma.reshape(expand_shape) + dnx = ograd * W + xsm = data - data_mean + nd = 1.0 / np.sqrt(data_var + epsilon) + nx = xsm * nd + m = _np.prod(shape) / shape[axis] + dvar = (dnx * xsm).sum(axis=reduce_axis, keepdims=True, + ) * (-0.5) * np.power(nd, 3) + dmean = -nd * dnx.sum(axis=reduce_axis, keepdims=True) - \ + dvar * xsm.mean(axis=reduce_axis, keepdims=True, + ) * 2.0 + dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) + dW = (ograd * nx).sum(axis=reduce_axis) + db = ograd.sum(axis=reduce_axis) + adX = dX if data_grad_req != 'add' else adX + dX + adW = dW if gamma_grad_req != 'add' else adW + dW + adb = db if beta_grad_req != 'add' else adb + db + + atol, rtol = 5e-2, 5e-2 + + if output_mean_var: + assert_almost_equal(output_mean.asnumpy(), + data_mean_flat.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output_std.asnumpy(), + (1.0 / (data_var_flat + + epsilon).sqrt()).asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + if is_train: + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + if data_grad_req != 'null': + assert_almost_equal(data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) + if not fix_gamma: + if gamma_grad_req != 'null': + assert_almost_equal( + bn_gamma.grad.asnumpy(), adW.asnumpy(), + atol=atol, rtol=rtol) + else: + assert((bn_gamma.asnumpy() == 1).all()) + if beta_grad_req != 'null': + assert_almost_equal( + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) + + for axis in range(len(shape)): + _test_batchnorm_impl(axis) + + @with_seed() @use_np def test_npx_softmax(): From b44d8c3272f7b1fcf60dde8ce833a84b8aeb6cd2 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 17:14:08 +0800 Subject: [PATCH 21/23] try to fix test --- tests/python/unittest/test_operator.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 31f5ba065a08..097c163d1aa4 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1838,7 +1838,7 @@ def check_batchnorm_training(stype): @xfail_when_nonstandard_decimal_separator @with_seed() -@pytest.mark.parametrize('op', [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]) +@pytest.mark.parametrize('op_name', ['BatchNorm', 'SyncBatchNorm']) @pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 4, 4, 5), (24, 8, 4, 5), (24, 5, 6, 4, 5)]) @pytest.mark.parametrize('fix_gamma', [False, True]) @@ -1847,18 +1847,24 @@ def check_batchnorm_training(stype): @pytest.mark.parametrize('beta_grad_req', ['null', 'write', 'add']) @pytest.mark.parametrize('cudnn_off', [False, True]) @pytest.mark.parametrize('output_mean_var', [False, True]) -def test_batchnorm(op, shape, fix_gamma, +def test_batchnorm(op_name, shape, fix_gamma, data_grad_req, gamma_grad_req, beta_grad_req, cudnn_off, output_mean_var): if fix_gamma and gamma_grad_req != 'null': # skip redundant test when fixing gamma return + if op_name == 'BatchNorm': + op = mx.nd.BatchNorm + elif op_name == 'SyncBatchNorm': + op = mx.nd.contrib.SyncBatchNorm + else: + raise ValueError(f'Not supported {op_name}') momentum = 0.9 epsilon = 1e-5 def _test_batchnorm_impl(axis): kwargs = dict(output_mean_var=output_mean_var) - if op == mx.nd.contrib.SyncBatchNorm: + if op_name == 'SyncBatchNorm': if axis != 1: return key = str(op) + str(shape) + str(axis) From 8a9bb35075cfada584940e56e3df0ffeb81a4697 Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 20:02:49 +0800 Subject: [PATCH 22/23] reduce the number of tests for batchnorm --- tests/python/unittest/test_numpy_op.py | 35 ++++++++++---------------- tests/python/unittest/test_operator.py | 27 ++++++++++---------- 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 60851d86c316..76224f74d39f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1598,30 +1598,14 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, @with_seed() @use_np -@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 4, 4, 5), +@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]) @pytest.mark.parametrize('fix_gamma', [False, True]) -@pytest.mark.parametrize('data_grad_req', ['null', 'write', 'add']) -@pytest.mark.parametrize('gamma_grad_req', ['null', 'write', 'add']) -@pytest.mark.parametrize('beta_grad_req', ['null', 'write', 'add']) @pytest.mark.parametrize('cudnn_off', [False, True]) @pytest.mark.parametrize('output_mean_var', [False, True]) -def test_npx_batch_norm(shape, fix_gamma, - data_grad_req, gamma_grad_req, beta_grad_req, - cudnn_off, output_mean_var): - shape = (24, 2) - fix_gamma = False - data_grad_req = 'write' - gamma_grad_req = 'write' - beta_grad_req = 'write' - cudnn_off = False - output_mean_var = False - +def test_npx_batch_norm(shape, fix_gamma, cudnn_off, output_mean_var): momentum = 0.9 epsilon = 1e-5 - if fix_gamma and gamma_grad_req != 'null': - # skip redundant test when fixing gamma - return class TestBatchNorm(HybridBlock): def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs): super().__init__() @@ -1638,7 +1622,8 @@ def hybrid_forward(self, F, data, bn_gamma, bn_beta, fix_gamma=self.fix_gamma, **self.kwargs) return output - def _test_batchnorm_impl(axis): + def _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req): kwargs = dict(output_mean_var=output_mean_var) kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs) @@ -1754,9 +1739,15 @@ def _test_batchnorm_impl(axis): assert_almost_equal( bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - for axis in range(len(shape)): - _test_batchnorm_impl(axis) - + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req) @with_seed() @use_np diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 097c163d1aa4..c4370e2732fc 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1839,20 +1839,12 @@ def check_batchnorm_training(stype): @xfail_when_nonstandard_decimal_separator @with_seed() @pytest.mark.parametrize('op_name', ['BatchNorm', 'SyncBatchNorm']) -@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 4, 4, 5), +@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]) @pytest.mark.parametrize('fix_gamma', [False, True]) -@pytest.mark.parametrize('data_grad_req', ['null', 'write', 'add']) -@pytest.mark.parametrize('gamma_grad_req', ['null', 'write', 'add']) -@pytest.mark.parametrize('beta_grad_req', ['null', 'write', 'add']) @pytest.mark.parametrize('cudnn_off', [False, True]) @pytest.mark.parametrize('output_mean_var', [False, True]) -def test_batchnorm(op_name, shape, fix_gamma, - data_grad_req, gamma_grad_req, beta_grad_req, - cudnn_off, output_mean_var): - if fix_gamma and gamma_grad_req != 'null': - # skip redundant test when fixing gamma - return +def test_batchnorm(op_name, shape, fix_gamma, cudnn_off, output_mean_var): if op_name == 'BatchNorm': op = mx.nd.BatchNorm elif op_name == 'SyncBatchNorm': @@ -1862,7 +1854,8 @@ def test_batchnorm(op_name, shape, fix_gamma, momentum = 0.9 epsilon = 1e-5 - def _test_batchnorm_impl(axis): + def _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req): kwargs = dict(output_mean_var=output_mean_var) if op_name == 'SyncBatchNorm': if axis != 1: @@ -1989,8 +1982,16 @@ def _test_batchnorm_impl(axis): assert_almost_equal( bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - for axis in range(len(shape)): - _test_batchnorm_impl(axis) + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req) + @with_seed() def test_groupnorm(): From bd06b8f3eef14463b79d7cd77d2292c35144a07d Mon Sep 17 00:00:00 2001 From: wkcn Date: Mon, 8 Jun 2020 20:18:30 +0800 Subject: [PATCH 23/23] fix --- tests/python/unittest/test_numpy_op.py | 22 +++++++++++----------- tests/python/unittest/test_operator.py | 18 +++++++++--------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 76224f74d39f..550b6dd42d32 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1714,8 +1714,8 @@ def _test_batchnorm_impl(axis, data_mean_flat.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(output_std.asnumpy(), - (1.0 / (data_var_flat + - epsilon).sqrt()).asnumpy(), + (1.0 / np.sqrt(data_var_flat + + epsilon)).asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=atol, rtol=rtol) @@ -1739,15 +1739,15 @@ def _test_batchnorm_impl(axis, assert_almost_equal( bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] - for data_grad_req in grad_reqs: - for gamma_grad_req in grad_reqs: - if fix_gamma and gamma_grad_req != 'null': - continue - for beta_grad_req in grad_reqs: - for axis in range(len(shape)): - _test_batchnorm_impl(axis, - data_grad_req, gamma_grad_req, beta_grad_req) + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req) @with_seed() @use_np diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c4370e2732fc..c4385880df64 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1982,15 +1982,15 @@ def _test_batchnorm_impl(axis, assert_almost_equal( bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] - for data_grad_req in grad_reqs: - for gamma_grad_req in grad_reqs: - if fix_gamma and gamma_grad_req != 'null': - continue - for beta_grad_req in grad_reqs: - for axis in range(len(shape)): - _test_batchnorm_impl(axis, - data_grad_req, gamma_grad_req, beta_grad_req) + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl(axis, + data_grad_req, gamma_grad_req, beta_grad_req) @with_seed()