Skip to content

Commit

Permalink
instance norm and reflection padding (apache#7938)
Browse files Browse the repository at this point in the history
* instance norm and reflection padding

* r prefix

* indent and space

* fix docs

* change docs

* spacing

* typo

* hybrid forward

* spcaing

* add test for instance norm

* fix typo

* add to __all__

* rm white space

* integer value

* add test

* make line short

* rm white space

* add docs ref

* fix docs

* RFpad2D docs

* read shape from weight

* rm condition
  • Loading branch information
zhanghang1989 authored and szha committed Feb 3, 2018
1 parent 793804d commit 14d1187
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/gluon/nn.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ This document lists the neural network blocks in Gluon:
Activation
Dropout
BatchNorm
InstanceNorm
LeakyReLU
Embedding
Flatten
Expand Down Expand Up @@ -62,6 +63,7 @@ This document lists the neural network blocks in Gluon:
GlobalAvgPool1D
GlobalAvgPool2D
GlobalAvgPool3D
ReflectionPad2D
```


Expand Down
85 changes: 82 additions & 3 deletions python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Activation',
'Dropout', 'BatchNorm', 'LeakyReLU', 'Embedding', 'Flatten',
'Lambda', 'HybridLambda']
'Dropout', 'BatchNorm', 'InstanceNorm', 'LeakyReLU', 'Embedding',
'Flatten', 'Lambda', 'HybridLambda']
import warnings
import numpy as np

Expand Down Expand Up @@ -480,6 +480,86 @@ def __repr__(self):
return self.__class__.__name__


class InstanceNorm(HybridBlock):
r"""
Applies instance normalization to the n-dimensional input array.
This operator takes an n-dimensional input array where (n>2) and normalizes
the input using the following formula:
.. math::
out = \frac{x - mean[data]}{ \sqrt{Var[data]} + \epsilon} * gamma + beta
Parameters
----------
epsilon: float, default 1e-5
Small float added to variance to avoid dividing by zero.
center: bool, default True
If True, add offset of `beta` to normalized tensor.
If False, `beta` is ignored.
scale: bool, default True
If True, multiply by `gamma`. If False, `gamma` is not used.
When the next layer is linear (also e.g. `nn.relu`),
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: str or `Initializer`, default 'zeros'
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
Initializer for the gamma weight.
in_channels : int, default 0
Number of channels (feature maps) in input data. If not specified,
initialization will be deferred to the first time `forward` is called
and `in_channels` will be inferred from the shape of input data.
Inputs:
- **data**: input tensor with arbitrary shape.
Outputs:
- **out**: output tensor with the same shape as `data`.
References
----------
`Instance Normalization: The Missing Ingredient for Fast Stylization
<https://arxiv.org/abs/1607.08022>`_
Examples
--------
>>> # Input of shape (2,1,2)
>>> x = mx.nd.array([[[ 1.1, 2.2]],
... [[ 3.3, 4.4]]])
>>> # Instance normalization is calculated with the above formula
>>> layer = InstanceNorm()
>>> layer.initialize(ctx=mx.cpu(0))
>>> layer(x)
[[[-0.99998355 0.99998331]]
[[-0.99998319 0.99998361]]]
<NDArray 2x1x2 @cpu(0)>
"""
def __init__(self, epsilon=1e-5, center=True, scale=False,
beta_initializer='zeros', gamma_initializer='ones',
in_channels=0, **kwargs):
super(InstanceNorm, self).__init__(**kwargs)
self._kwargs = {'eps': epsilon}
self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
shape=(in_channels,), init=gamma_initializer,
allow_deferred_init=True)
self.beta = self.params.get('beta', grad_req='write' if center else 'null',
shape=(in_channels,), init=beta_initializer,
allow_deferred_init=True)

def hybrid_forward(self, F, x, gamma, beta):
return F.InstanceNorm(x, gamma, beta,
name='fwd', **self._kwargs)

def __repr__(self):
s = '{name}({content}'
in_channels = self.gamma.shape[0]
s += ', in_channels={0}'.format(in_channels)
s += ')'
return s.format(name=self.__class__.__name__,
content=', '.join(['='.join([k, v.__repr__()])
for k, v in self._kwargs.items()]))

class Lambda(Block):
r"""Wraps an operator or an expression as a Block object.
Expand Down Expand Up @@ -526,7 +606,6 @@ def __repr__(self):
class HybridLambda(HybridBlock):
r"""Wraps an operator or an expression as a HybridBlock object.
Parameters
----------
function : str or function
Expand Down
34 changes: 33 additions & 1 deletion python/mxnet/gluon/nn/conv_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
'MaxPool1D', 'MaxPool2D', 'MaxPool3D',
'AvgPool1D', 'AvgPool2D', 'AvgPool3D',
'GlobalMaxPool1D', 'GlobalMaxPool2D', 'GlobalMaxPool3D',
'GlobalAvgPool1D', 'GlobalAvgPool2D', 'GlobalAvgPool3D']
'GlobalAvgPool1D', 'GlobalAvgPool2D', 'GlobalAvgPool3D',
'ReflectionPad2D']

from ..block import HybridBlock
from ... import symbol
Expand Down Expand Up @@ -1007,3 +1008,34 @@ def __init__(self, layout='NCDHW', **kwargs):
assert layout == 'NCDHW', "Only supports NCDHW layout for now"
super(GlobalAvgPool3D, self).__init__(
(1, 1, 1), None, 0, True, True, 'avg', **kwargs)


class ReflectionPad2D(HybridBlock):
"""Pads the input tensor using the reflection of the input boundary.
Parameters
----------
padding: int
An integer padding size
Shape:
- Input: :math:`(N, C, H_{in}, W_{in})`
- Output: :math:`(N, C, H_{out}, W_{out})` where
:math:`H_{out} = H_{in} + 2 * padding
:math:`W_{out} = W_{in} + 2 * padding
Examples
--------
>>> m = nn.ReflectionPad2D(3)
>>> input = mx.nd.random.normal(shape=(16, 3, 224, 224))
>>> output = m(input)
"""
def __init__(self, padding=0, **kwargs):
super(ReflectionPad2D, self).__init__(**kwargs)
if isinstance(padding, numeric_types):
padding = (0, 0, 0, 0, padding, padding, padding, padding)
assert(len(padding) == 8)
self._padding = padding

def hybrid_forward(self, F, x):
return F.pad(x, mode='reflect', pad_width=self._padding)
11 changes: 11 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,22 @@ def test_pool():
layer.collect_params().initialize()
assert (layer(x).shape==(2, 2, 4, 4))


def test_batchnorm():
layer = nn.BatchNorm(in_channels=10)
check_layer_forward(layer, (2, 10, 10, 10))


def test_instancenorm():
layer = nn.InstanceNorm(in_channels=10)
check_layer_forward(layer, (2, 10, 10, 10))


def test_reflectionpad():
layer = nn.ReflectionPad2D(3)
check_layer_forward(layer, (2, 3, 24, 24))


def test_reshape():
x = mx.nd.ones((2, 4, 10, 10))
layer = nn.Conv2D(10, 2, in_channels=4)
Expand Down

0 comments on commit 14d1187

Please sign in to comment.