diff --git a/tools/coreml/converter/_layers.py b/tools/coreml/converter/_layers.py index d0113a915571..8f4bc1a8a02c 100644 --- a/tools/coreml/converter/_layers.py +++ b/tools/coreml/converter/_layers.py @@ -472,11 +472,14 @@ def convert_batchnorm(net, node, module, builder): inputs = node['inputs'] - eps = 1e-3 # Default value of eps for MXNet. - use_global_stats = False # Default value of use_global_stats for MXNet. + eps = 1e-3 # Default value of eps for MXNet. + use_global_stats = False # Default value of use_global_stats for MXNet. + fix_gamma = True # Default value of fix_gamma for MXNet. attrs = _get_attrs(node) if 'eps' in attrs: eps = literal_eval(attrs['eps']) + if 'fix_gamma' in attrs: + fix_gamma = literal_eval(attrs['fix_gamma']) args, aux = module.get_params() gamma = args[_get_node_name(net, inputs[1][0])].asnumpy() @@ -484,6 +487,8 @@ def convert_batchnorm(net, node, module, builder): mean = aux[_get_node_name(net, inputs[3][0])].asnumpy() variance = aux[_get_node_name(net, inputs[4][0])].asnumpy() nb_channels = gamma.shape[0] + if fix_gamma: + gamma.fill(1.) builder.add_batchnorm( name=name, channels=nb_channels, diff --git a/tools/coreml/test/test_mxnet_converter.py b/tools/coreml/test/test_mxnet_converter.py index 5d26c5faf754..bc850690a572 100644 --- a/tools/coreml/test/test_mxnet_converter.py +++ b/tools/coreml/test/test_mxnet_converter.py @@ -938,6 +938,40 @@ def test_batch_norm_no_global_stats(self): name='batch_norm_1') self._test_mxnet_model(net, input_shape=input_shape, mode='random', delta=1e-2) + def test_batch_norm_with_fix_gamma(self): + """ The gamma will always be an array of ones when fix_gamma=True. The values + of gamma may be changed accidentally if there have been fix_gamma=False before + the final trained model. + """ + np.random.seed(1988) + input_shape = (1, 1, 2, 3) + + net = mx.sym.Variable('data') + gamma = mx.sym.Variable('gamma') + beta = mx.sym.Variable('beta') + moving_mean = mx.sym.Variable('moving_mean') + moving_var = mx.sym.Variable('moving_var') + net = mx.symbol.BatchNorm( + data=net, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + fix_gamma=True, + name='batch_norm_1') + self._test_mxnet_model(net, input_shape=input_shape, mode='random', delta=1e-2) + + np.random.seed(1988) + net = mx.symbol.BatchNorm( + data=net, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + fix_gamma=False, + name='batch_norm_2') + self._test_mxnet_model(net, input_shape=input_shape, mode='random', delta=1e-2) + def test_pre_processing_args(self): np.random.seed(1988) input_shape = (1, 10)