diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 485b3b33f6a8..17a16db5adcd 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -259,7 +259,6 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
                        const std::vector<TBlob> &outputs) {
   CHECK_EQ(inputs.size(), 8U);
   CHECK_EQ(outputs.size(), 3U);
-  std::vector<TBlob> out_grad(1);
   std::vector<TBlob> out_data(3);
   std::vector<TBlob> in_data(3);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index fc65476f6d50..3214e3b9b9ac 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -85,31 +85,6 @@ static inline void ForEachFast(const BNTensor3<DType1> &in_data,
   }
 }
 
-template<typename DType1, typename DType2, typename DType3, typename OnData>
-static inline void ForEachFast(const BNTensor3<DType1> &in_data,
-                               const BNTensor3<DType2> &in_data2,
-                               const BNTensor3<DType3> &out_data,
-                               const size_t channel,
-                               OnData onData) {
-  const size_t num = in_data.OuterSize();
-  const size_t matrixSize = in_data.InnerSize();
-  const size_t skipLength = in_data.SkipLengthToNextSameChannelData();
-  const size_t startOffset = in_data.StartOffset(channel);
-
-  DType1 *data = in_data.dptr_ + startOffset;
-  DType2 *data2 = in_data2.dptr_ + startOffset;
-  DType3 *odata = out_data.dptr_ + startOffset;
-
-  for (size_t outer = 0; outer < num; ++outer) {
-    for (size_t i = 0; i < matrixSize; ++i) {
-      onData(data++, data2++, odata++);
-    }
-    data += skipLength;
-    data2 += skipLength;
-    odata += skipLength;
-  }
-}
-
 }  // namespace batchnorm
 
 /*! \brief Forward CPU */
@@ -289,7 +264,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
                   dotp += (*thisInputData - mean) * (*gradOut_data);
                 });
 
-    if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) {  // if there's a grad input
+    if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) {  // if there's a grad input
       if (is_train_and_not_global_stats) {
         // when in training mode
         // Q(X) = X - E[x] ; i.e. input centered to zero mean
@@ -298,60 +273,44 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
 
         // projection of gradOutput on to output scaled by std
         const AccReal k = dotp * invstd * invstd / itemCount;
+        ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
+                    [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
+                      *gradIn_data = (*inputDataPtr - mean) * k;
+                    });
+
         const AccReal iw = invstd * w;
         const AccReal gradMean = sumGradOut / itemCount;
-        if (req[batchnorm::kData] != kAddTo) {
-          ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
-                      [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
-                        *gradIn_data = (*inputDataPtr - mean) * k;
-                      });
-
-          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                      [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
-                        *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
-                      });
-        } else {
-          ForEachFast(inputData, gradOut, gradIn, static_cast<size_t>(channel),
-                      [&mean, &k, iw, gradMean](const DType *inputDataPtr,
-                                                const DType *gradOut_data,
-                                                DType *gradIn_data) {
-                        DType normal_val = (*inputDataPtr - mean) * k;
-                        *gradIn_data += (*gradOut_data - gradMean -
-                                         normal_val) * iw;
-                      });
-        }
+        ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                    [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
+                      *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
+                    });
       } else {
         // when in evaluation mode
         // Q(X) = X - running_mean ; i.e. input centered to zero mean
         // Y = Q(X) / running_std ; i.e. BN output before weight and bias
         // dL/dX = w / running_std
         const AccReal iw = invstd * w;
-        if (req[batchnorm::kData] != kAddTo) {
-          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                      [iw](const DType *gradOut_data, DType *gradIn_data) {
-                        *gradIn_data = *gradOut_data * iw;
-                      });
-        } else {
-          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                      [iw](const DType *gradOut_data, DType *gradIn_data) {
-                        *gradIn_data += *gradOut_data * iw;
-                      });
-        }
+        ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                    [iw](const DType *gradOut_data, DType *gradIn_data) {
+                      *gradIn_data = *gradOut_data * iw;
+                    });
       }
     }
 
     // May want to make this a param eventually
     const AccReal scale = 1.0f;
 
-    if (!param_.fix_gamma) {
-      KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd);
-    } else {
-      if (IsBNWriting(req[batchnorm::kGamma])) {
+    if (IsBNWriting(req[batchnorm::kGamma])) {
+      if (!param_.fix_gamma) {
+        gradWeightData[channel] = scale * dotp * invstd;
+      } else {
        gradWeightData[channel] = AccReal(0);
       }
     }
 
-    KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut);
+    if (IsBNWriting(req[batchnorm::kBeta])) {
+      gradBiasData[channel] = scale * sumGradOut;
+    }
   }
 }
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 7b36d25e7496..be9309c8bfb1 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -34,9 +34,6 @@
 #define FIX_GAMMA_FLAG 8
 #define IS_TRAINING_FLAG 16
 #define USE_GLOBAL_STATS_FLAG 32
-#define ADDTO_DATA_FLAG (1 << 6)
-#define ADDTO_GAMMA_FLAG (1 << 7)
-#define ADDTO_BETA_FLAG (1 << 8)
 
 #if MXNET_USE_CUDNN == 1
 #include "./cudnn/cudnn_batch_norm-inl.h"
@@ -365,60 +362,33 @@ static __global__ void BatchNormalizationBackwardKernel(
       * momentum + localVariance * (AccReal(1) - momentum);
   }
 
-  if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
-    const bool grad_write = flags & WRITE_DATA_FLAG;
-    if (grad_write) {
-      for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
-        for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
-          const DType gradOut = gradOutput.get_ref(batch, plane, x);
-          if (is_train_and_not_global_stats) {
-            const DType inp = input.get_ref(batch, plane, x);
-            const AccReal proj = (inp - mean) * projScale;
-            gradInput.get_ref(batch, plane, x) =
-              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
-          } else {
-            gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
-              gradOut * gradScale);
-          }
-        }
-      }
-    } else {
-      // grad addto
-      for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
-        for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
-          const DType gradOut = gradOutput.get_ref(batch, plane, x);
-          if (is_train_and_not_global_stats) {
-            const DType inp = input.get_ref(batch, plane, x);
-            const AccReal proj = (inp - mean) * projScale;
-            gradInput.get_ref(batch, plane, x) +=
-              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
-          } else {
-            gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
-              gradOut * gradScale);
-          }
+  if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
+    for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+      for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+        const DType gradOut = gradOutput.get_ref(batch, plane, x);
+        if (is_train_and_not_global_stats) {
+          const DType inp = input.get_ref(batch, plane, x);
+          const AccReal proj = (inp - mean) * projScale;
+          gradInput.get_ref(batch, plane, x) =
+            ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+        } else {
+          gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
+            gradOut * gradScale);
         }
       }
     }
   }
 
-  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
-      (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
+  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
     if ((flags & FIX_GAMMA_FLAG) == 0) {
-      if (flags & WRITE_GAMMA_FLAG)
-        tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
-      else
-        tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
+      tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
     } else {
       tensors.gradWeight[plane] = DType(0);
     }
   }
 
-  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
-      (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
-    if (flags & WRITE_BETA_FLAG)
-      tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
-    else
-      tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
+  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
+    tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
   }
 }
 
@@ -615,18 +585,12 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
   if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
-  } else if (req[batchnorm::kData] == kAddTo) {
-    flags |= ADDTO_DATA_FLAG;
   }
   if (IsBNWriting(req[batchnorm::kGamma])) {
     flags |= WRITE_GAMMA_FLAG;
-  } else if (req[batchnorm::kGamma] == kAddTo) {
-    flags |= ADDTO_GAMMA_FLAG;
   }
   if (IsBNWriting(req[batchnorm::kBeta])) {
     flags |= WRITE_BETA_FLAG;
-  } else if (req[batchnorm::kBeta] == kAddTo) {
-    flags |= ADDTO_BETA_FLAG;
   }
   return flags;
 }
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 5dad073c2815..3fc91196708c 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -208,24 +208,13 @@ class CuDNNBatchNormOp {
       if (param_.fix_gamma) gamma = 1.f;
 
-      bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) ||
-                                 (req[cudnnbatchnorm::kBeta] == kAddTo);
-      if (grad_add_gamma_beta) {
-        if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
-          dgamma = 0.f;
-        }
-        if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
-          dbeta = 0.f;
-        }
-      }
-
       CUDNN_CALL(cudnnBatchNormalizationBackward(
         s->dnn_handle_,
         mode,
         &a,
-        req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
+        &b,
         &a,
-        grad_add_gamma_beta ? &b_add : &b,  // gamma and beta
+        req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
         io_desc_,
         x.dptr_,
         io_desc_,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 2063aa4d3472..26637c7c0b65 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -317,8 +317,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   else if (diff.IsDefaultData())
     diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
   auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem = CreateMKLDNNMem(const_cast<NDArray &>(gradIn),
-                                   bwd.pd.diff_src_desc(), req[batchnorm::kData]);
+  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
 
   if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma = in_data[batchnorm::kGamma];
@@ -338,7 +337,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
 
     mkldnn_args_map_t net_args;
     net_args[MKLDNN_ARG_SRC] = *data_mem;
-    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second;
+    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
     net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
     net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
     net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
@@ -363,46 +362,26 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       }
       net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
       net_args[MKLDNN_ARG_VARIANCE] = var_mem;
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+      MKLDNNStream::Get()->Submit();
     } else {
       net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
       net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+      MKLDNNStream::Get()->Submit();
     }
-    MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
-    CommitOutput(gradIn, gradi_mem);
-    MKLDNNStream::Get()->Submit();
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
    DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
-    DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
-    DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
-
-    // the gradient of gamma
-    if (!param.fix_gamma) {
-      if (req[batchnorm::kGamma] != kNullOp) {
-        if (req[batchnorm::kGamma] != kAddTo) {
-          memcpy(w_grad_1, gw_buf, copy_size);
-        } else {
-          for (int i = 0; i < channels_; i++) {
-            w_grad_1[i] += gw_buf[i];
-          }
-        }
-      }
-    } else {
-      for (int i = 0; i < channels_; i++) {
+    for (int i = 0; i < channels_; i++) {
+      if (!param.fix_gamma)
+        (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
+      else
         (in_grad[1].data().dptr<DType>())[i] = 0.0f;
-      }
     }
 
-    // the gradient of beta
-    if (req[batchnorm::kBeta] != kNullOp) {
-      if (req[batchnorm::kBeta] != kAddTo) {
-        memcpy(w_grad_2, &gw_buf[channels_], copy_size);
-      } else {
-        DType *grad_beta = &gw_buf[channels_];
-        for (int i = 0; i < channels_; i++) {
-          w_grad_2[i] += grad_beta[i];
-        }
-      }
+    for (int i = 0; i < channels_; i++) {
+      (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
     }
   } else {
     LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 0652a0597c69..1ff1b6139cce 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1071,198 +1071,6 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req,
                                             transpose_b=transpose_b))
 
 
-@with_seed()
-@use_np
-@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4),
-                                   (24, 8, 4, 5), (24, 5, 6, 4, 5)])
-@pytest.mark.parametrize('fix_gamma', [False, True])
-@pytest.mark.parametrize('cudnn_off', [False, True])
-@pytest.mark.parametrize('output_mean_var', [False, True])
-def test_npx_batch_norm(shape, fix_gamma, cudnn_off, output_mean_var):
-    momentum = 0.9
-    epsilon = 1e-5
-
-    class TestBatchNorm(HybridBlock):
-        def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs):
-            super().__init__()
-            self.eps = eps
-            self.fix_gamma = fix_gamma
-            self.momentum = momentum
-            self.kwargs = kwargs
-
-        def hybrid_forward(self, F, data, bn_gamma, bn_beta,
-                           bn_running_mean, bn_running_var):
-            op = F.npx.batch_norm
-            output = op(data, bn_gamma, bn_beta,
-                        bn_running_mean, bn_running_var,
-                        momentum=self.momentum, eps=self.eps,
-                        fix_gamma=self.fix_gamma, **self.kwargs)
-            return output
-
-    def _test_batchnorm_impl(axis,
-                             data_grad_req, gamma_grad_req, beta_grad_req):
-        kwargs = dict(output_mean_var=output_mean_var)
-        kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
-        op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs)
-        nch = shape[axis]
-
-        if not fix_gamma:
-            bn_gamma = np.random.uniform(size=(nch,))
-            bn_gamma.attach_grad(grad_req=gamma_grad_req)
-        else:
-            bn_gamma = np.ones((nch,))
-
-        bn_beta = np.random.uniform(size=(nch,))
-        bn_beta.attach_grad(grad_req=beta_grad_req)
-
-        bn_running_mean = np.zeros(nch)
-        bn_running_var = np.ones(nch)
-
-        running_mean = np.zeros(nch)
-        running_var = np.ones(nch)
-        num_iters = 10
-        expand_shape = [1] * len(shape)
-        expand_shape[axis] = shape[axis]
-        expand_shape = tuple(expand_shape)
-        data = np.random.uniform(size=shape)
-        data.attach_grad(grad_req=data_grad_req)
-        adX, adW, adb = 0, 0, 0
-        is_train = data_grad_req != 'null' or \
-            (not fix_gamma and gamma_grad_req != 'null') or \
-            beta_grad_req != 'null'
-        for _ in range(num_iters):
-            if data_grad_req != 'add':
-                data = np.random.uniform(size=shape)
-                data.attach_grad(grad_req=data_grad_req)
-            ograd = np.random.uniform(size=shape)
-            with mx.autograd.record():
-                output = op(data, bn_gamma, bn_beta,
-                            bn_running_mean, bn_running_var)
-                if output_mean_var:
-                    output, output_mean, output_std = output
-                if is_train:
-                    output.backward(ograd)
-            mx.nd.waitall()
-
-            assert 0 <= axis < data.ndim
-            reduce_axis = tuple(i for i in range(data.ndim) if i != axis)
-            assert len(reduce_axis) == data.ndim - 1
-            data_mean = data.mean(
-                axis=reduce_axis, keepdims=True)
-            data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis,
-                                                      keepdims=True)
-
-            target_output = (data - data_mean) / \
-                np.sqrt(data_var + epsilon) * \
-                bn_gamma.reshape(expand_shape) + \
-                bn_beta.reshape(expand_shape)
-
-            # squeeze data_mean and data_var
-            data_mean_flat = data_mean.squeeze()
-            data_var_flat = data_var.squeeze()
-
-            running_mean = running_mean * momentum + \
-                data_mean_flat * (1 - momentum)
-            running_var = running_var * momentum + \
-                data_var_flat * (1 - momentum)
-
-            W = bn_gamma.reshape(expand_shape)
-            dnx = ograd * W
-            xsm = data - data_mean
-            nd = 1.0 / np.sqrt(data_var + epsilon)
-            nx = xsm * nd
-            m = _np.prod(shape) / shape[axis]
-            dvar = (dnx * xsm).sum(axis=reduce_axis, keepdims=True,
-                                   ) * (-0.5) * np.power(nd, 3)
-            dmean = -nd * dnx.sum(axis=reduce_axis, keepdims=True) - \
-                dvar * xsm.mean(axis=reduce_axis, keepdims=True,
-                                ) * 2.0
-            dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
-            dW = (ograd * nx).sum(axis=reduce_axis)
-            db = ograd.sum(axis=reduce_axis)
-            adX = dX if data_grad_req != 'add' else adX + dX
-            adW = dW if gamma_grad_req != 'add' else adW + dW
-            adb = db if beta_grad_req != 'add' else adb + db
-
-            atol, rtol = 5e-2, 5e-2
-
-            if output_mean_var:
-                assert_almost_equal(output_mean.asnumpy(),
-                                    data_mean_flat.asnumpy(),
-                                    atol=atol, rtol=rtol)
-                assert_almost_equal(output_std.asnumpy(),
-                                    (1.0 / np.sqrt(data_var_flat +
-                                                   epsilon)).asnumpy(),
-                                    atol=atol, rtol=rtol)
-            assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
-                                atol=atol, rtol=rtol)
-            if is_train:
-                assert_almost_equal(bn_running_mean.asnumpy(
-                ), running_mean.asnumpy(), atol=atol, rtol=rtol)
-                assert_almost_equal(bn_running_var.asnumpy(
-                ), running_var.asnumpy(), atol=atol, rtol=rtol)
-
-            if data_grad_req != 'null':
-                assert_almost_equal(data.grad.asnumpy(),
-                                    adX.asnumpy(), atol=atol, rtol=rtol)
-            if not fix_gamma:
-                if gamma_grad_req != 'null':
-                    assert_almost_equal(
-                        bn_gamma.grad.asnumpy(), adW.asnumpy(),
-                        atol=atol, rtol=rtol)
-            else:
-                assert((bn_gamma.asnumpy() == 1).all())
-            if beta_grad_req != 'null':
-                assert_almost_equal(
-                    bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
-
-    grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
-    for data_grad_req in grad_reqs:
-        for gamma_grad_req in grad_reqs:
-            if fix_gamma and gamma_grad_req != 'null':
-                continue
-            for beta_grad_req in grad_reqs:
-                for axis in range(len(shape)):
-                    _test_batchnorm_impl(axis,
-                                         data_grad_req, gamma_grad_req, beta_grad_req)
-
-
-@with_seed()
-@use_np
-def test_npx_softmax():
-    class TestSoftmax(HybridBlock):
-        def __init__(self, axis):
-            super(TestSoftmax, self).__init__()
-            self._axis = axis
-
-        def hybrid_forward(self, F, a):
-            return F.npx.softmax(a, axis=axis)
-
-    def np_softmax(x, axis=-1):
-        if (x.shape[axis] == 0):
-            return _np.sum(x, axis=axis, keepdims=True)
-        x = x - _np.max(x, axis=axis, keepdims=True)
-        x = _np.exp(x)
-        x /= _np.sum(x, axis=axis, keepdims=True)
-        return x
-
-    # only testing 0-size shaped inputs here, other input cases have been tested in test_opeartor.py
-    for hybridize in [True, False]:
-        for shape in [(3, 0, 4), (0, 0)]:
-            mx_a = np.random.uniform(size=shape)
-            mx_a.attach_grad()
-            for axis in range(-len(shape), len(shape)):
-                test_softmax = TestSoftmax(axis)
-                if hybridize:
-                    test_softmax.hybridize()
-
-                with mx.autograd.record():
-                    mx_out = test_softmax(mx_a)
-
-                np_out = np_softmax(mx_a.asnumpy(), axis)
-                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
-
-                mx_out.backward()
-                assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
-
-
 @with_seed()
 @use_np
 def test_npi_boolean_assign():
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b7650d5ac5c1..9ae35f15748a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1822,26 +1822,15 @@ def check_batchnorm_training(stype):
 
 
 @with_seed()
-@pytest.mark.parametrize('op_name', ['BatchNorm', 'SyncBatchNorm'])
-@pytest.mark.parametrize('shape', [(24, 2), (24, 3, 4),
-                                   (24, 8, 4, 5), (24, 5, 6, 4, 5)])
-@pytest.mark.parametrize('fix_gamma', [False, True])
-@pytest.mark.parametrize('cudnn_off', [False, True])
-@pytest.mark.parametrize('output_mean_var', [False, True])
-def test_batchnorm(op_name, shape, fix_gamma, cudnn_off, output_mean_var):
-    if op_name == 'BatchNorm':
-        op = mx.nd.BatchNorm
-    elif op_name == 'SyncBatchNorm':
-        op = mx.nd.contrib.SyncBatchNorm
-    else:
-        raise ValueError(f'Not supported {op_name}')
+def test_batchnorm():
     momentum = 0.9
     epsilon = 1e-5
 
-    def _test_batchnorm_impl(axis,
-                             data_grad_req, gamma_grad_req, beta_grad_req):
+    def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
+        print(str((op, shape, axis, cudnn_off)))
+
         kwargs = dict(output_mean_var=output_mean_var)
-        if op_name == 'SyncBatchNorm':
+        if op == mx.nd.contrib.SyncBatchNorm:
             if axis != 1:
                 return
             key = str(op) + str(shape) + str(axis)
@@ -1852,14 +1841,11 @@ def _test_batchnorm_impl(axis,
         kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
         nch = shape[axis]
 
-        if not fix_gamma:
-            bn_gamma = mx.nd.random.uniform(shape=(nch,))
-            bn_gamma.attach_grad(grad_req=gamma_grad_req)
-        else:
-            bn_gamma = mx.nd.ones(shape=(nch,))
+        bn_gamma = mx.nd.random.uniform(shape=(nch,))
+        bn_gamma.attach_grad()
 
         bn_beta = mx.nd.random.uniform(shape=(nch,))
-        bn_beta.attach_grad(grad_req=beta_grad_req)
+        bn_beta.attach_grad()
 
         bn_running_mean = mx.nd.zeros(nch)
         bn_running_var = mx.nd.ones(nch)
@@ -1869,26 +1855,18 @@ def _test_batchnorm_impl(axis,
         num_iters = 10
         expand_shape = [1] * len(shape)
         expand_shape[axis] = shape[axis]
-        data = mx.nd.random.uniform(shape=shape)
-        data.attach_grad(grad_req=data_grad_req)
-        adX, adW, adb = 0, 0, 0
-        is_train = data_grad_req != 'null' or \
-            (not fix_gamma and gamma_grad_req != 'null') or \
-            beta_grad_req != 'null'
         for _ in range(num_iters):
-            if data_grad_req != 'add':
-                data = mx.nd.random.uniform(shape=shape)
-                data.attach_grad(grad_req=data_grad_req)
+            data = mx.nd.random.uniform(shape=shape)
+            data.attach_grad()
             ograd = mx.nd.random.uniform(shape=shape)
             with mx.autograd.record():
                 output = op(data, bn_gamma, bn_beta,
                             bn_running_mean, bn_running_var,
                             momentum=momentum, eps=epsilon,
-                            fix_gamma=fix_gamma, **kwargs)
+                            fix_gamma=False, **kwargs)
                 if output_mean_var:
                     output, output_mean, output_std = output
-                if is_train:
-                    output.backward(ograd)
+                output.backward(ograd)
             mx.nd.waitall()
 
             data_mean = data.mean(
@@ -1925,11 +1903,9 @@ def _test_batchnorm_impl(axis,
             dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
             dW = (ograd * nx).sum(axis=axis, exclude=True)
             db = ograd.sum(axis=axis, exclude=True)
-            adX = dX if data_grad_req != 'add' else adX + dX
-            adW = dW if gamma_grad_req != 'add' else adW + dW
-            adb = db if beta_grad_req != 'add' else adb + db
 
-            atol, rtol = 5e-2, 5e-2
+            atol = 1e-2
+            rtol = 1e-2
 
             if output_mean_var:
                 assert_almost_equal(output_mean.asnumpy(),
@@ -1946,35 +1922,25 @@ def _test_batchnorm_impl(axis,
                                     atol=atol, rtol=rtol)
             assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
                                 atol=atol, rtol=rtol)
-            if is_train:
-                assert_almost_equal(bn_running_mean.asnumpy(
-                ), running_mean.asnumpy(), atol=atol, rtol=rtol)
-                assert_almost_equal(bn_running_var.asnumpy(
-                ), running_var.asnumpy(), atol=atol, rtol=rtol)
-
-            if data_grad_req != 'null':
-                assert_almost_equal(data.grad.asnumpy(),
-                                    adX.asnumpy(), atol=atol, rtol=rtol)
-            if not fix_gamma:
-                if gamma_grad_req != 'null':
-                    assert_almost_equal(
-                        bn_gamma.grad.asnumpy(), adW.asnumpy(),
-                        atol=atol, rtol=rtol)
-            else:
-                assert((bn_gamma.asnumpy() == 1).all())
-            if beta_grad_req != 'null':
-                assert_almost_equal(
-                    bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
-
-    grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
-    for data_grad_req in grad_reqs:
-        for gamma_grad_req in grad_reqs:
-            if fix_gamma and gamma_grad_req != 'null':
-                continue
-            for beta_grad_req in grad_reqs:
-                for axis in range(len(shape)):
-                    _test_batchnorm_impl(axis,
-                                         data_grad_req, gamma_grad_req, beta_grad_req)
+            assert_almost_equal(bn_running_mean.asnumpy(
+            ), running_mean.asnumpy(), atol=atol, rtol=rtol)
+            assert_almost_equal(bn_running_var.asnumpy(
+            ), running_var.asnumpy(), atol=atol, rtol=rtol)
+
+            assert_almost_equal(data.grad.asnumpy(),
+                                dX.asnumpy(), atol=atol, rtol=rtol)
+            assert_almost_equal(
+                bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol)
+            assert_almost_equal(
+                bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol)
+
+    for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
+        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
+            for axis in range(len(shape)):
+                for cudnn_off in [False, True]:
+                    for output_mean_var in [False, True]:
+                        _test_batchnorm_impl(op, shape, axis,
+                                             cudnn_off, output_mean_var)
 
 
 @with_seed()