Skip to content

Commit

Permalink
Fix BatchNorm converter for CoreML when fix_gamma=True (apache#13557)
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG authored and haohuw committed Jun 23, 2019
1 parent 92fe0bf commit ab0c913
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tools/coreml/converter/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,18 +472,23 @@ def convert_batchnorm(net, node, module, builder):
inputs = node['inputs']


eps = 1e-3 # Default value of eps for MXNet.
use_global_stats = False # Default value of use_global_stats for MXNet.
eps = 1e-3 # Default value of eps for MXNet.
use_global_stats = False # Default value of use_global_stats for MXNet.
fix_gamma = True # Default value of fix_gamma for MXNet.
attrs = _get_attrs(node)
if 'eps' in attrs:
eps = literal_eval(attrs['eps'])
if 'fix_gamma' in attrs:
fix_gamma = literal_eval(attrs['fix_gamma'])

args, aux = module.get_params()
gamma = args[_get_node_name(net, inputs[1][0])].asnumpy()
beta = args[_get_node_name(net, inputs[2][0])].asnumpy()
mean = aux[_get_node_name(net, inputs[3][0])].asnumpy()
variance = aux[_get_node_name(net, inputs[4][0])].asnumpy()
nb_channels = gamma.shape[0]
if fix_gamma:
gamma.fill(1.)
builder.add_batchnorm(
name=name,
channels=nb_channels,
Expand Down
34 changes: 34 additions & 0 deletions tools/coreml/test/test_mxnet_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,40 @@ def test_batch_norm_no_global_stats(self):
name='batch_norm_1')
self._test_mxnet_model(net, input_shape=input_shape, mode='random', delta=1e-2)

def test_batch_norm_with_fix_gamma(self):
    """Check BatchNorm conversion with fix_gamma=True and then fix_gamma=False.

    All BatchNorm parameters (gamma, beta, moving_mean, moving_var) are
    declared as explicit variables and shared by both layers. The first
    network is a single BatchNorm with fix_gamma=True; the second stacks a
    BatchNorm with fix_gamma=False on top of the first one. Each network is
    run through ``_test_mxnet_model`` in 'random' mode.

    Rationale (from the original change): with fix_gamma=True gamma is meant
    to behave as an array of ones, even if the stored gamma values differ
    from one, e.g. because fix_gamma=False was used at some earlier point.
    """
    input_shape = (1, 1, 2, 3)

    net = mx.sym.Variable('data')
    # Parameter symbols reused by both BatchNorm layers below.
    shared_params = dict(
        gamma=mx.sym.Variable('gamma'),
        beta=mx.sym.Variable('beta'),
        moving_mean=mx.sym.Variable('moving_mean'),
        moving_var=mx.sym.Variable('moving_var'),
    )

    layer_specs = (
        (True, 'batch_norm_1'),
        (False, 'batch_norm_2'),
    )
    for fix_gamma, layer_name in layer_specs:
        # Re-seed before every check so each one starts from the same
        # NumPy random state.
        np.random.seed(1988)
        # `net` is rebound on purpose: the second layer takes the first
        # layer's output as its input.
        net = mx.symbol.BatchNorm(
            data=net,
            fix_gamma=fix_gamma,
            name=layer_name,
            **shared_params)
        self._test_mxnet_model(
            net, input_shape=input_shape, mode='random', delta=1e-2)

def test_pre_processing_args(self):
np.random.seed(1988)
input_shape = (1, 10)
Expand Down

0 comments on commit ab0c913

Please sign in to comment.