Gluon for GroupNorm and unit test
haojin2 committed May 20, 2019
1 parent 7019f56 commit 2ed60e4
Showing 3 changed files with 105 additions and 7 deletions.
91 changes: 90 additions & 1 deletion python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,8 @@
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+           'Flatten', 'Lambda', 'HybridLambda']
import warnings
import numpy as np

@@ -616,6 +617,94 @@ def __repr__(self):
for k, v in self._kwargs.items()]))


class GroupNorm(HybridBlock):
    r"""
    Applies group normalization to the n-dimensional input array.
    This operator takes an n-dimensional input array where the leftmost two axes are
    `batch` and `channel`, respectively:

    .. math::

      x = x.reshape((N, num_groups, C // num_groups, ...))
      axis = (2, ...)
      out = \frac{x - mean[x, axis]}{\sqrt{Var[x, axis] + \epsilon}} * gamma + beta

    Parameters
    ----------
    num_groups: int, default 1
        Number of groups to separate the channel axis into.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.

    Inputs:
        - **data**: input tensor with arbitrary shape.

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    References
    ----------
        `Group Normalization
        <https://arxiv.org/pdf/1803.08494.pdf>`_

    Examples
    --------
    >>> # Input of shape (2, 3, 4)
    >>> x = mx.nd.array([[[ 0,  1,  2,  3],
                          [ 4,  5,  6,  7],
                          [ 8,  9, 10, 11]],
                         [[12, 13, 14, 15],
                          [16, 17, 18, 19],
                          [20, 21, 22, 23]]])
    >>> # Group normalization is calculated with the above formula
    >>> layer = GroupNorm()
    >>> layer.initialize(ctx=mx.cpu(0))
    >>> layer(x)
    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
    <NDArray 2x3x4 @cpu(0)>
    """
    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 prefix=None, params=None):
        super(GroupNorm, self).__init__(prefix=prefix, params=params)
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                     shape=(num_groups,), init=gamma_initializer,
                                     allow_deferred_init=True)
        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
                                    shape=(num_groups,), init=beta_initializer,
                                    allow_deferred_init=True)

    def hybrid_forward(self, F, data, gamma, beta):
        norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
        return norm_data

    def __repr__(self):
        s = '{name}({content})'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))


class Lambda(Block):
r"""Wraps an operator or an expression as a Block object.
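For reference, the normalization described by the docstring formula above can be checked with a few lines of NumPy. This is a minimal sketch, not the C++ operator the layer dispatches to; the helper name group_norm_ref is made up for illustration, and it assumes num_groups evenly divides the channel dimension.

import numpy as np

def group_norm_ref(x, num_groups=1, eps=1e-5, gamma=None, beta=None):
    """Reference GroupNorm: per sample, normalize over each group of channels plus spatial dims."""
    n = x.shape[0]
    g = num_groups
    y = x.reshape(n, g, -1)                          # (N, num_groups, C // num_groups * spatial)
    mean = y.mean(axis=2, keepdims=True)
    var = y.var(axis=2, keepdims=True)
    y = (y - mean) / np.sqrt(var + eps)
    gamma = np.ones(g) if gamma is None else gamma   # per-group scale, shape (num_groups,)
    beta = np.zeros(g) if beta is None else beta     # per-group shift, shape (num_groups,)
    y = y * gamma.reshape(1, g, 1) + beta.reshape(1, g, 1)
    return y.reshape(x.shape)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
print(group_norm_ref(x)[0, 0])  # approx. [-1.5933 -1.3036 -1.0139 -0.7242], matching the docstring example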
12 changes: 6 additions & 6 deletions src/operator/nn/group_norm-inl.h
@@ -222,14 +222,14 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
reduce_workspace_size =
std::max(reduce_workspace_size,
-                      broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_src_shape,
-                                                                  kAddTo, red_dst_shape));
+                      broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape,
+                                                                  kAddTo, red_src_shape));
});
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
reduce_workspace_size =
std::max(reduce_workspace_size,
-                      broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_src_shape, kAddTo,
-                                                                  red_exclude_dst_shape));
+                      broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_dst_shape, kAddTo,
+                                                                  red_exclude_src_shape));
});
});
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
@@ -264,7 +264,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity, true>(
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
});
@@ -285,7 +285,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity, true>(
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, op::mshadow_op::identity, true>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
});
9 changes: 9 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -714,6 +714,15 @@ def test_layernorm():
    check_layer_forward(layer, (2, 10, 10, 10))


@with_seed()
def test_groupnorm():
    layer = nn.GroupNorm()
    check_layer_forward(layer, (2, 10, 10, 10))
    layer = nn.GroupNorm(num_groups=2)
    check_layer_forward(layer, (2, 10, 10, 10))
    layer = nn.GroupNorm(num_groups=5)
    check_layer_forward(layer, (2, 10, 10, 10))

@with_seed()
def test_reflectionpad():
    layer = nn.ReflectionPad2D(3)
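As a usage note, here is a short sketch of how the new layer might be dropped into a Gluon model once this change is in. The network, shapes, and group count below are illustrative rather than taken from the test above, and they assume num_groups evenly divides the channel count. Because GroupNorm normalizes within each sample, its behavior does not depend on batch size, and as a HybridBlock it works under hybridization.

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Conv2D(channels=10, kernel_size=3),
            nn.GroupNorm(num_groups=2),   # 2 groups over the 10 conv channels
            nn.Activation('relu'))
net.initialize(ctx=mx.cpu(0))
net.hybridize()                           # GroupNorm is a HybridBlock, so the whole net can be hybridized

x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
print(net(x).shape)                       # (2, 10, 30, 30)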
