From a4b85a50bcdd3d6aaadf3a54164611aec4d6f0b7 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Wed, 3 Apr 2019 01:53:02 +0800 Subject: [PATCH] Support SyncBatchNorm5D (#14542) * support SyncBatchNorm5D * fix * update testcase and reformat code * retrigger CI * update test case * test * Retrigger CI * disable cudnn for batchnorm * fix BatchNorm(cudnn) * fix build * Remove a testcase * Update sync_batch_norm-inl.h * update unittest * update unittest * update test * fix test * change atol and rtol * BN(cudnn) 5d * update test * test * Testing * Update batch_norm.cu * test cudnnoff * Update test_operator.py * update BN! : ) --- src/operator/contrib/sync_batch_norm-inl.h | 31 +-- src/operator/nn/batch_norm.cu | 4 +- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 14 +- tests/python/gpu/test_gluon_gpu.py | 216 ++++++++----------- tests/python/unittest/test_gluon.py | 120 +++++++++++ tests/python/unittest/test_operator.py | 122 +++++++++++ 6 files changed, 355 insertions(+), 152 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index b94416640f55..1e6ab25db0e2 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -69,7 +69,6 @@ struct SyncBatchNormParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(ndev).set_default(1) .describe("The count of GPU devices"); DMLC_DECLARE_FIELD(key) - .set_default("") .describe("Hash key for synchronization, please set the same hash key for same layer, " "Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`."); } @@ -275,14 +274,18 @@ class SyncBatchNorm : public Operator { static_cast(in_data[syncbatchnorm::kData].shape_.Size()); Tensor data; Tensor out; - if (in_data[syncbatchnorm::kData].ndim() == 2) { + if (in_data[syncbatchnorm::kData].ndim() == 4) { + data = in_data[syncbatchnorm::kData].get(s); + out = out_data[syncbatchnorm::kOut].get(s); + } else { + index_t num_channels = in_data[syncbatchnorm::kData].ndim() > 1 ? + in_data[syncbatchnorm::kData].shape_[1] : 1; + index_t spatial_size = in_data[syncbatchnorm::kData].shape_.ProdShape(2, + in_data[syncbatchnorm::kData].ndim()); Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], - in_data[syncbatchnorm::kData].shape_[1], 1, 1); + num_channels, 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); out = out_data[syncbatchnorm::kOut].get_with_shape(dshape, s); - } else { - data = in_data[syncbatchnorm::kData].get(s); - out = out_data[syncbatchnorm::kOut].get(s); } Tensor slope = in_data[syncbatchnorm::kGamma].get(s); Tensor bias = in_data[syncbatchnorm::kBeta].get(s); @@ -354,16 +357,20 @@ class SyncBatchNorm : public Operator { Tensor data, grad, grad_in; const real_t scale = static_cast(out_grad[syncbatchnorm::kOut].shape_[1]) / static_cast(out_grad[syncbatchnorm::kOut].shape_.Size()); - if (in_data[syncbatchnorm::kData].ndim() == 2) { + if (in_data[syncbatchnorm::kData].ndim() == 4) { + data = in_data[syncbatchnorm::kData].get(s); + grad = out_grad[syncbatchnorm::kOut].get(s); + grad_in = in_grad[syncbatchnorm::kData].get(s); + } else { + index_t num_channels = out_grad[syncbatchnorm::kOut].ndim() > 1 ? + out_grad[syncbatchnorm::kOut].shape_[1] : 1; + index_t spatial_size = out_grad[syncbatchnorm::kOut].shape_.ProdShape(2, + out_grad[syncbatchnorm::kOut].ndim()); Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], - out_grad[syncbatchnorm::kOut].shape_[1], 1, 1); + num_channels, 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); grad_in = in_grad[syncbatchnorm::kData].get_with_shape(dshape, s); - } else { - data = in_data[syncbatchnorm::kData].get(s); - grad = out_grad[syncbatchnorm::kOut].get(s); - grad_in = in_grad[syncbatchnorm::kData].get(s); } Tensor mean = out_data[syncbatchnorm::kMean].get(s); diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 1199ec7fcce5..9fb44e8fae81 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -668,7 +668,7 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 - if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 + if (!param.use_global_stats && !param.cudnn_off && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); @@ -697,7 +697,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 - if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 + if (!param.use_global_stats && !param.cudnn_off && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Backward(ctx, inputs, req, outputs); diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index d4b9f84ed2f5..820f8504d74c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -84,7 +84,6 @@ class CuDNNBatchNormOp { } CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo); CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2); - CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4); Init(in_data[cudnnbatchnorm::kData]); Stream *s = ctx.get_stream(); @@ -273,12 +272,15 @@ class CuDNNBatchNormOp { private: void Init(const TBlob &in_data) { - for (int i = 0; i < 4; ++i) { - if (i < in_data.ndim()) { + if (in_data.ndim() == 4) { + for (int i = 0; i < 4; ++i) shape_[i] = in_data.shape_[i]; - } else { - shape_[i] = 1; - } + } else { + // when in_data.ndim() != 4 + shape_[0] = in_data.shape_[0]; + shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; + shape_[2] = 1; + shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 9eeeec749211..1c5a5835e6f9 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -45,6 +45,7 @@ set_default_context(mx.gpu(0)) + def check_rnn_layer(layer): layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) with mx.gpu(0): @@ -62,6 +63,7 @@ def check_rnn_layer(layer): for g, c in zip(gs, cs): assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) + @with_seed() def check_rnn_layer_w_rand_inputs(layer): layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) @@ -89,11 +91,12 @@ def test_lstmp(): batch_size, seq_len = 7, 11 input_size = 5 ctx = mx.gpu(0) - lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx) - shapes = {'i2h_weight': (hidden_size*4, input_size), - 'h2h_weight': (hidden_size*4, projection_size), - 'i2h_bias': (hidden_size*4,), - 'h2h_bias': (hidden_size*4,), + lstm_input = mx.nd.uniform( + shape=(seq_len, batch_size, input_size), ctx=ctx) + shapes = {'i2h_weight': (hidden_size * 4, input_size), + 'h2h_weight': (hidden_size * 4, projection_size), + 'i2h_bias': (hidden_size * 4,), + 'h2h_bias': (hidden_size * 4,), 'h2r_weight': (projection_size, hidden_size)} weights = {k: rand_ndarray(v) for k, v in shapes.items()} lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, @@ -107,23 +110,26 @@ def test_lstmp(): layer_params = lstm_layer.collect_params() cell_params = lstm_cell.collect_params() for k, v in weights.items(): - layer_params['lstm0_l0_'+k].set_data(v.copy()) - cell_params['lstm0_l0_'+k].set_data(v.copy()) + layer_params['lstm0_l0_' + k].set_data(v.copy()) + cell_params['lstm0_l0_' + k].set_data(v.copy()) with autograd.record(): layer_output = lstm_layer(lstm_input.copy()) cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC', merge_outputs=True)[0] - assert_almost_equal(layer_output.asnumpy(), cell_output.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(layer_output.asnumpy(), + cell_output.asnumpy(), rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() for k, v in weights.items(): - layer_grad = layer_params['lstm0_l0_'+k].grad() - cell_grad = cell_params['lstm0_l0_'+k].grad() - print('checking gradient for {}'.format('lstm0_l0_'+k)) + layer_grad = layer_params['lstm0_l0_' + k].grad() + cell_grad = cell_params['lstm0_l0_' + k].grad() + print('checking gradient for {}'.format('lstm0_l0_' + k)) assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(), rtol=rtol, atol=atol) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM( + 10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones( + (8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx) check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)), run_only=True, ctx=ctx) @@ -139,7 +145,8 @@ def test_lstm_clip(): batch_size, seq_len = 32, 80 input_size = 50 clip_min, clip_max, clip_nan = -5, 5, True - lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0)) + lstm_input = mx.nd.uniform( + shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0)) lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), ctx=mx.gpu(0)), mx.nd.uniform(shape=(2, batch_size, hidden_size), ctx=mx.gpu(0))] lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, @@ -165,7 +172,8 @@ def test_rnn_layer(): check_rnn_layer(gluon.rnn.GRU(100, num_layers=3)) check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) - check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) + check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM( + 100, num_layers=3, bidirectional=True)) def check_layer_bidirectional(size, in_size, proj_size): @@ -173,8 +181,10 @@ class RefBiLSTM(gluon.Block): def __init__(self, size, proj_size, **kwargs): super(RefBiLSTM, self).__init__(**kwargs) with self.name_scope(): - self._lstm_fwd = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=False, prefix='l0') - self._lstm_bwd = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=False, prefix='r0') + self._lstm_fwd = gluon.rnn.LSTM( + size, projection_size=proj_size, bidirectional=False, prefix='l0') + self._lstm_bwd = gluon.rnn.LSTM( + size, projection_size=proj_size, bidirectional=False, prefix='r0') def forward(self, inpt): fwd = self._lstm_fwd(inpt) @@ -184,16 +194,23 @@ def forward(self, inpt): return nd.concat(fwd, bwd, dim=2) weights = {} for d in ['l', 'r']: - weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) + weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform( + shape=(size * 4, in_size)) if proj_size: - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, proj_size)) - weights['lstm_{}0_h2r_weight'.format(d)] = mx.random.uniform(shape=(proj_size, size)) + weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform( + shape=(size * 4, proj_size)) + weights['lstm_{}0_h2r_weight'.format(d)] = mx.random.uniform( + shape=(proj_size, size)) else: - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) - weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - - net = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=True, prefix='lstm_') + weights['lstm_{}0_h2h_weight'.format( + d)] = mx.random.uniform(shape=(size * 4, size)) + weights['lstm_{}0_i2h_bias'.format( + d)] = mx.random.uniform(shape=(size * 4,)) + weights['lstm_{}0_h2h_bias'.format( + d)] = mx.random.uniform(shape=(size * 4,)) + + net = gluon.rnn.LSTM(size, projection_size=proj_size, + bidirectional=True, prefix='lstm_') ref_net = RefBiLSTM(size, proj_size, prefix='lstm_') net.initialize() ref_net.initialize() @@ -201,16 +218,19 @@ def forward(self, inpt): ref_net_params = ref_net.collect_params() for k in weights: net_params[k].set_data(weights[k]) - ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k]) + ref_net_params[k.replace('l0', 'l0l0').replace( + 'r0', 'r0l0')].set_data(weights[k]) data = mx.random.uniform(shape=(11, 10, in_size)) assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy()) + @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_layer_bidirectional(): check_layer_bidirectional(7, 5, 0) + @with_seed() @assert_raises_cudnn_not_satisfied(min_version='7.2.1') def test_layer_bidirectional_proj(): @@ -221,7 +241,8 @@ def test_layer_bidirectional_proj(): @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnn_layer_begin_state_type(): fake_data = nd.random.uniform(shape=(3, 5, 7), dtype='float16') - modeling_layer = gluon.rnn.LSTM(hidden_size=11, num_layers=2, dropout=0.2, bidirectional=True) + modeling_layer = gluon.rnn.LSTM( + hidden_size=11, num_layers=2, dropout=0.2, bidirectional=True) modeling_layer.cast('float16') modeling_layer.initialize() modeling_layer(fake_data) @@ -229,9 +250,10 @@ def test_rnn_layer_begin_state_type(): def test_gluon_ctc_consistency(): loss = mx.gluon.loss.CTCLoss() - data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0) - cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0)) - gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0)) + data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0) + ).reshape((2, 20, 4)).flip(axis=0) + cpu_label = mx.nd.array([[2, 1, -1, -1], [3, 2, 2, -1]], ctx=mx.cpu(0)) + gpu_label = mx.nd.array([[2, 1, -1, -1], [3, 2, 2, -1]], ctx=mx.gpu(0)) cpu_data = data.copy().as_in_context(mx.cpu(0)) cpu_data.attach_grad() @@ -245,15 +267,17 @@ def test_gluon_ctc_consistency(): l_gpu = loss(gpu_data, gpu_label) l_gpu.backward() - assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(cpu_data.grad.asnumpy(), + gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) @with_seed() def test_global_norm_clip_multi_device(): for check_isfinite in [True, False]: - x1 = mx.nd.ones((3,3), ctx=mx.gpu(0)) - x2 = mx.nd.ones((4,4), ctx=mx.cpu(0)) - norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite) + x1 = mx.nd.ones((3, 3), ctx=mx.gpu(0)) + x2 = mx.nd.ones((4, 4), ctx=mx.cpu(0)) + norm = gluon.utils.clip_global_norm( + [x1, x2], 1.0, check_isfinite=check_isfinite) if check_isfinite: assert norm == 5.0 else: @@ -262,86 +286,6 @@ def test_global_norm_clip_multi_device(): assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5) -def _check_batchnorm_result(input, num_devices=1, cuda=False): - from mxnet.gluon.utils import split_and_load - def _find_bn(module): - if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2, ctx): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(ctx)) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - nch = input.shape[1] - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) - - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) - -@with_seed() -def test_sync_batchnorm(): - def get_num_devices(): - for i in range(100): - try: - mx.nd.zeros((1,), ctx=mx.gpu(i)) - except: - return i - # no need to use SyncBN with 1 gpu - if get_num_devices() < 2: - return - ndev = 2 - # check with unsync version - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - - @with_seed() def test_symbol_block_fp16(): # Test case to verify if initializing the SymbolBlock from a model with params @@ -352,10 +296,11 @@ def test_symbol_block_fp16(): tmpfile = os.path.join(tmp, 'resnet34_fp16') ctx = mx.gpu(0) - net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp) + net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2( + pretrained=True, ctx=ctx, root=tmp) net_fp32.cast('float16') net_fp32.hybridize() - data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx) + data = mx.nd.zeros((1, 3, 224, 224), dtype='float16', ctx=ctx) net_fp32.forward(data) net_fp32.export(tmpfile, 0) @@ -389,7 +334,8 @@ def test_large_models(): # Compute the height (=width) of the square tensor of the given size in bytes def tensor_size(big_tensor_bytes): bytes_per_float = 4 - sz = int(math.sqrt(big_tensor_bytes / largest_num_features / bytes_per_float)) + sz = int(math.sqrt(big_tensor_bytes / + largest_num_features / bytes_per_float)) return (sz // 100) * 100 # The idea is to create models with large tensors of (say) 20% of the total memory. @@ -398,12 +344,13 @@ def tensor_size(big_tensor_bytes): (free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id) start_size = tensor_size(0.20 * total_mem_bytes) num_trials = 10 - sys.stderr.write(' testing global memory of size {} ... '.format(total_mem_bytes)) + sys.stderr.write( + ' testing global memory of size {} ... '.format(total_mem_bytes)) sys.stderr.flush() for i in range(num_trials): sz = start_size - 10 * i - (height, width) = (sz,sz) - sys.stderr.write(" {}x{} ".format(height,width)) + (height, width) = (sz, sz) + sys.stderr.write(" {}x{} ".format(height, width)) sys.stderr.flush() data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height, width), ctx=ctx, dtype="float32") @@ -411,6 +358,8 @@ def tensor_size(big_tensor_bytes): net(data_in).asnumpy() # isolated execution bulking test function to be invoked with different env var settings + + def _test_bulking_in_process(seed, time_per_iteration): # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused. class Flip(gluon.HybridBlock): @@ -440,7 +389,7 @@ def get_net(num_ops): # time a number of forward() and backward() executions after some warm-up iterations warmups = 1 - for i in range(num_iterations+warmups): + for i in range(num_iterations + warmups): with autograd.record(): if i == warmups: start = time.time() @@ -450,20 +399,22 @@ def get_net(num_ops): time_per_iteration.value = (time.time() - start) / num_iterations + @with_seed() def test_bulking(): # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) - test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] + test_cases = [(0, 0, True), (1, 1, True), (15, 15, False), + (15, 0, True), (0, 15, True), (15, 15, True)] times = {} times_str = '' for seg_sizes in test_cases: # Create shared variable to return measured time from test process time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(_test_bulking_in_process, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1], - 'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]}, - time_per_iteration): + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0], + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1], + 'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]}, + time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. return times[seg_sizes] = time_per_iteration.value @@ -471,21 +422,22 @@ def test_bulking(): '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) - fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) - slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) - fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) - fully_bulked_time = times[(15,15,True)] + fastest_non_bulked_time = min( + times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)]) + slowest_half_bulked_time = max(times[(0, 15, True)], times[(15, 0, True)]) + fastest_half_bulked_time = min(times[(0, 15, True)], times[(15, 0, True)]) + fully_bulked_time = times[(15, 15, True)] print(times_str) # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, # slower than both half-bulked times[0,15,True] and times[15,0,True] assert slowest_half_bulked_time < fastest_non_bulked_time, \ 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ - .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) + .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) # The fully bulked times[15,15,True] should be faster than both half-bulked runs assert fully_bulked_time < fastest_half_bulked_time, \ 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \ - .format(fully_bulked_time - fastest_half_bulked_time, times_str) + .format(fully_bulked_time - fastest_half_bulked_time, times_str) if __name__ == '__main__': diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 6af7a5f948e2..61ebc2f6b13b 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -583,6 +583,126 @@ def test_batchnorm(): check_layer_forward(layer, (2, 10, 10, 10)) +@with_seed() +def test_sync_batchnorm(): + def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] if input.ndim > 1 else 1 + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm( + in_channels=nch, num_devices=num_devices) + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) + for output in output2], dim=0) + # check bn1 + + momentum = 0.9 + epsilon = 1e-5 + axis = 1 + data = input1 + running_mean = mx.nd.zeros(nch, ctx=data.context) + running_var = mx.nd.ones(nch, ctx=data.context) + + data_mean = data.mean( + axis=axis, exclude=True, keepdims=True) + data_var = (data - data_mean).square().mean(axis=axis, + exclude=True, keepdims=True) + + target_output = (data - data_mean) / (data_var + epsilon).sqrt() + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + atol = 1e-2 + rtol = 1e-2 + assert_almost_equal(output1.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + running_mean.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + running_var.asnumpy(), + atol=atol, rtol=rtol) + # assert forwarding + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output1.asnumpy(), + output2.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), + atol=atol, rtol=rtol) + input2grad = mx.nd.concat( + *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad.asnumpy(), + input2grad.asnumpy(), atol=atol, rtol=rtol) + + cfgs = [(1, False)] + num_gpus = mx.context.num_gpus() + for i in range(1, num_gpus + 1): + cfgs.append((i, True)) + for ndev, cuda in cfgs: + # check with unsync version + for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: + print(str((ndev, cuda, shape))) + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=shape, + ctx=mx.cpu(0)), + num_devices=ndev, cuda=cuda) + + @with_seed() def test_instancenorm(): layer = nn.InstanceNorm(in_channels=10) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c9498ecb0bd2..845ae113c218 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1598,6 +1598,128 @@ def check_batchnorm_training(stype): check_batchnorm_training('default') +@with_seed() +def test_batchnorm(): + momentum = 0.9 + epsilon = 1e-5 + + def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): + print(str((op, shape, axis, cudnn_off))) + + kwargs = dict(output_mean_var=output_mean_var) + if op == mx.nd.contrib.SyncBatchNorm: + if axis != 1: + return + key = str(op) + str(shape) + str(axis) + kwargs.update(dict(key=key)) + if cudnn_off: + return + else: + kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) + nch = shape[axis] + + bn_gamma = mx.nd.random.uniform(shape=(nch,)) + bn_gamma.attach_grad() + + bn_beta = mx.nd.random.uniform(shape=(nch,)) + bn_beta.attach_grad() + + bn_running_mean = mx.nd.zeros(nch) + bn_running_var = mx.nd.ones(nch) + + running_mean = mx.nd.zeros(nch) + running_var = mx.nd.ones(nch) + num_iters = 10 + expand_shape = [1] * len(shape) + expand_shape[axis] = shape[axis] + for _ in range(num_iters): + data = mx.nd.random.uniform(shape=shape) + data.attach_grad() + ograd = mx.nd.random.uniform(shape=shape) + with mx.autograd.record(): + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var, + momentum=momentum, eps=epsilon, + fix_gamma=False, **kwargs) + if output_mean_var: + output, output_mean, output_std = output + output.backward(ograd) + mx.nd.waitall() + + data_mean = data.mean( + axis=axis, exclude=True, keepdims=True) + data_var = (data - data_mean).square().mean(axis=axis, + exclude=True, + keepdims=True) + + target_output = (data - data_mean) / \ + (data_var + epsilon).sqrt() * \ + bn_gamma.reshape(expand_shape) + \ + bn_beta.reshape(expand_shape) + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + W = bn_gamma.reshape(expand_shape) + dnx = ograd * W + xsm = data - data_mean + nd = 1.0 / mx.nd.sqrt(data_var + epsilon) + nx = xsm * nd + m = np.prod(shape) / shape[axis] + dvar = (dnx * xsm).sum(axis=axis, keepdims=True, + exclude=True) * (-0.5) * mx.nd.power(nd, 3) + dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \ + dvar * xsm.mean(axis=axis, keepdims=True, + exclude=True) * 2.0 + dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) + dW = (ograd * nx).sum(axis=axis, exclude=True) + db = ograd.sum(axis=axis, exclude=True) + + atol = 1e-2 + rtol = 1e-2 + + if output_mean_var: + assert_almost_equal(output_mean.asnumpy(), + data_mean_flat.asnumpy(), + atol=atol, rtol=rtol) + if op != mx.nd.contrib.SyncBatchNorm: + assert_almost_equal(output_std.asnumpy(), + (1.0 / (data_var_flat + + epsilon).sqrt()).asnumpy(), + atol=atol, rtol=rtol) + else: + assert_almost_equal(output_std.asnumpy(), + data_var_flat.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + assert_almost_equal(data.grad.asnumpy(), + dX.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal( + bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal( + bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) + + for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: + for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: + for axis in range(len(shape)): + for cudnn_off in [False, True]: + for output_mean_var in [False, True]: + _test_batchnorm_impl(op, shape, axis, + cudnn_off, output_mean_var) + + @with_seed() def test_convolution_grouping(): for dim in [1, 2, 3]: