
Group Normalization #14959

Merged
3 commits merged on Jul 19, 2019
1 change: 1 addition & 0 deletions python/mxnet/contrib/amp/lists/symbol.py
@@ -471,6 +471,7 @@
    'log_softmax',
    'InstanceNorm',
    'LayerNorm',
    'GroupNorm',
    'L2Normalization',
    'LRN',
    'SoftmaxActivation',
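Note (not part of the diff): judging by the neighboring entries (`InstanceNorm`, `LayerNorm`, `LRN`), this hunk appears to extend the list of symbols that AMP keeps in float32; the list's name lies outside the hunk, so that reading is an assumption. A minimal sketch of how such a list would take effect, using only the documented mxnet.contrib.amp entry point:

from mxnet.contrib import amp

# Hedged sketch: amp.init() patches MXNet so eligible ops run in float16,
# while ops named on the FP32 list -- which this change would extend with
# 'GroupNorm' -- keep float32 inputs and outputs.
amp.init()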
91 changes: 90 additions & 1 deletion python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,8 @@
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
           'Flatten', 'Lambda', 'HybridLambda']
import warnings
import numpy as np

@@ -616,6 +617,94 @@ def __repr__(self):
                                           for k, v in self._kwargs.items()]))


class GroupNorm(HybridBlock):
r"""
Applies group normalization to the n-dimensional input array.
This operator takes an n-dimensional input array where the leftmost 2 axis are
`batch` and `channel` respectively:

    .. math::

      x = x.reshape((N, num_groups, C // num_groups, ...))
      axis = (2, ...)
      out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta

    Parameters
    ----------
    num_groups: int, default 1
        Number of groups to separate the channel axis into.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.


    Inputs:
        - **data**: input tensor with shape (N, C, ...).

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    References
    ----------
        `Group Normalization
        <https://arxiv.org/pdf/1803.08494.pdf>`_

    Examples
    --------
    >>> # Input of shape (2, 3, 4)
    >>> x = mx.nd.array([[[ 0, 1, 2, 3],
                          [ 4, 5, 6, 7],
                          [ 8, 9, 10, 11]],
                         [[12, 13, 14, 15],
                          [16, 17, 18, 19],
                          [20, 21, 22, 23]]])
    >>> # Group normalization is calculated with the above formula
    >>> layer = GroupNorm()
    >>> layer.initialize(ctx=mx.cpu(0))
    >>> layer(x)
    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413 0.1448413 0.4345239]
      [ 0.7242065 1.0138891 1.3035717 1.5932543]]
     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413 0.1448413 0.4345239]
      [ 0.7242065 1.0138891 1.3035717 1.5932543]]]
    <NDArray 2x3x4 @cpu(0)>
    """
    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 prefix=None, params=None):
        super(GroupNorm, self).__init__(prefix=prefix, params=params)
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(num_groups,), init=gamma_initializer,
                                     allow_deferred_init=True)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(num_groups,), init=beta_initializer,
                                    allow_deferred_init=True)

    def hybrid_forward(self, F, data, gamma, beta):
        norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
        return norm_data

    def __repr__(self):
        s = '{name}({content})'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))


class Lambda(Block):
r"""Wraps an operator or an expression as a Block object.

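Usage note (not part of the diff): a minimal sketch of the new Gluon block with a non-default num_groups, assuming the backend `GroupNorm` operator invoked by hybrid_forward is available in the build; the channel count must be divisible by num_groups.

import mxnet as mx
from mxnet.gluon import nn

# Hedged sketch: 4 channels split into 2 groups of 2 channels each.
layer = nn.GroupNorm(num_groups=2)
layer.initialize(ctx=mx.cpu(0))

x = mx.nd.random.uniform(shape=(2, 4, 8, 8))  # (N, C, H, W)
y = layer(x)                                  # same shape as x, normalized per (sample, group)
print(y.shape)                                # (2, 4, 8, 8)

With num_groups=1 (the default) all channels of a sample are normalized together, and with num_groups equal to the channel count each channel is normalized on its own, which mirrors instance normalization.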