
Group Normalization (#14959)
* GroupNorm

* add to amp list

* re-write forward
haojin2 authored and sxjscience committed Jul 19, 2019
1 parent b887c06 commit eec0fb4
Showing 7 changed files with 706 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/mxnet/contrib/amp/lists/symbol.py
@@ -471,6 +471,7 @@
'log_softmax',
'InstanceNorm',
'LayerNorm',
'GroupNorm',
'L2Normalization',
'LRN',
'SoftmaxActivation',
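Note on the change above: the surrounding entries ('InstanceNorm', 'LayerNorm', 'LRN', 'SoftmaxActivation') appear to belong to the list of operators that automatic mixed precision keeps in float32 for numerical stability, so adding 'GroupNorm' here should prevent AMP from casting it to float16. A minimal sketch of using the new layer together with AMP follows; it assumes an MXNet build that includes this commit, AMP is normally used with a GPU context, and the toy network, shapes, and num_groups value are made up for illustration.

import mxnet as mx
from mxnet.contrib import amp
from mxnet.gluon import nn

amp.init()  # patch MXNet so float16/float32 casts are inserted automatically

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Conv2D(channels=8, kernel_size=3, padding=1))
    net.add(nn.GroupNorm(num_groups=4))   # the layer added by this commit
    net.add(nn.Activation('relu'))
net.initialize(ctx=mx.cpu(0))
net.hybridize()

out = net(mx.nd.random.uniform(shape=(2, 3, 16, 16)))  # GroupNorm is expected to stay in float32 under AMP
print(out.shape)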
91 changes: 90 additions & 1 deletion python/mxnet/gluon/nn/basic_layers.py
Expand Up @@ -19,7 +19,8 @@
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
'Flatten', 'Lambda', 'HybridLambda']
import warnings
import numpy as np

@@ -616,6 +617,94 @@ def __repr__(self):
for k, v in self._kwargs.items()]))


class GroupNorm(HybridBlock):
r"""
Applies group normalization to the n-dimensional input array.
This operator takes an n-dimensional input array where the leftmost 2 axis are
`batch` and `channel` respectively:
.. math::
x = x.reshape((N, num_groups, C // num_groups, ...))
axis = (2, ...)
out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta
Parameters
----------
num_groups: int, default 1
Number of groups to separate the channel axis into.
epsilon: float, default 1e-5
Small float added to variance to avoid dividing by zero.
center: bool, default True
If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: bool, default True
If True, multiply by `gamma`. If False, `gamma` is not used.
beta_initializer: str or `Initializer`, default 'zeros'
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
Initializer for the gamma weight.
Inputs:
- **data**: input tensor with shape (N, C, ...).
Outputs:
- **out**: output tensor with the same shape as `data`.
References
----------
`Group Normalization
<https://arxiv.org/pdf/1803.08494.pdf>`_
Examples
--------
>>> # Input of shape (2, 3, 4)
>>> x = mx.nd.array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> # Group normalization is calculated with the above formula
>>> layer = GroupNorm()
>>> layer.initialize(ctx=mx.cpu(0))
>>> layer(x)
[[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
[-0.4345239 -0.1448413 0.1448413 0.4345239]
[ 0.7242065 1.0138891 1.3035717 1.5932543]]
[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
[-0.4345239 -0.1448413 0.1448413 0.4345239]
[ 0.7242065 1.0138891 1.3035717 1.5932543]]]
<NDArray 2x3x4 @cpu(0)>
"""
    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 prefix=None, params=None):
        super(GroupNorm, self).__init__(prefix=prefix, params=params)
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(num_groups,), init=gamma_initializer,
                                     allow_deferred_init=True)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(num_groups,), init=beta_initializer,
                                    allow_deferred_init=True)

    def hybrid_forward(self, F, data, gamma, beta):
        norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
        return norm_data

    def __repr__(self):
        s = '{name}({content})'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))


class Lambda(Block):
r"""Wraps an operator or an expression as a Block object.
(Diffs for the remaining 5 changed files are not shown.)
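For reference, the group-normalization formula in the docstring above can be checked by hand. Below is a small verification sketch, assuming the defaults of the layer as committed here (num_groups=1, eps=1e-5) and freshly initialized gamma=1 and beta=0 so the affine terms drop out; it reproduces the values shown in the docstring example.

import numpy as np
import mxnet as mx
from mxnet.gluon import nn

x = mx.nd.arange(24).reshape((2, 3, 4))   # same values as the docstring example

# Manual group normalization with num_groups=1: reshape to
# (N, num_groups, C // num_groups, ...), then normalize over axes (2, ...).
xn = x.asnumpy().reshape((2, 1, 3, 4))
mean = xn.mean(axis=(2, 3), keepdims=True)
var = xn.var(axis=(2, 3), keepdims=True)
manual = ((xn - mean) / np.sqrt(var + 1e-5)).reshape((2, 3, 4))

layer = nn.GroupNorm()        # num_groups=1; gamma=1 and beta=0 right after initialization
layer.initialize()
out = layer(x)

print(np.allclose(out.asnumpy(), manual, atol=1e-5))   # expected: True

Note that in this version of the layer, gamma and beta have shape (num_groups,), i.e. one scale and one offset per group, and hybrid_forward simply delegates to the backend GroupNorm operator that this commit introduces.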