Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix BatchNorm backward synchronization #18644

Merged
merged 2 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ then set ``gamma`` to 1 and its gradient to 0.
NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_inputs(8)
.set_num_outputs(3)
.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{6, 7}; // moving_mean, moving_var
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#if MXNET_USE_MKLDNN == 1
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,32 @@ def transpose(shape):
assert (layer(x).shape==ceil_out_shape)


@with_seed()
@pytest.mark.parametrize('variable', ['running_var', 'running_mean'])
def test_batchnorm_backward_synchronization(variable):
"""
Tests if synchronization of BatchNorm running variables is done correctly.
If not, the test sometimes fails - depending on the timing.
"""
ctx = mx.test_utils.default_context()

for _ in range(20):
layer = nn.BatchNorm()
layer.initialize(ctx=ctx)
for _ in range(3):
data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 10, 10), ctx=ctx)
with mx.autograd.record():
out = layer(data)
out.backward()

# check if each read give the same value
var1 = getattr(layer, variable).data().asnumpy()
for _ in range(10):
var2 = getattr(layer, variable).data().asnumpy()
if (var1 != var2).any():
raise AssertionError("Two consecutive reads of " + variable + " give different results")


@with_seed()
def test_batchnorm():
layer = nn.BatchNorm(in_channels=10)
Expand Down