diff --git a/docs/api/python/gluon/nn.md b/docs/api/python/gluon/nn.md index 5e2dbe016d61..4515644c6b44 100644 --- a/docs/api/python/gluon/nn.md +++ b/docs/api/python/gluon/nn.md @@ -20,6 +20,7 @@ This document lists the neural network blocks in Gluon: Activation Dropout BatchNorm + InstanceNorm LeakyReLU Embedding Flatten @@ -62,6 +63,7 @@ This document lists the neural network blocks in Gluon: GlobalAvgPool1D GlobalAvgPool2D GlobalAvgPool3D + ReflectionPad2D ``` diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index 43b4bdaf1cf0..15175395533c 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -19,8 +19,8 @@ # pylint: disable= arguments-differ """Basic neural network layers.""" __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Activation', - 'Dropout', 'BatchNorm', 'LeakyReLU', 'Embedding', 'Flatten', - 'Lambda', 'HybridLambda'] + 'Dropout', 'BatchNorm', 'InstanceNorm', 'LeakyReLU', 'Embedding', + 'Flatten', 'Lambda', 'HybridLambda'] import warnings import numpy as np @@ -480,6 +480,86 @@ def __repr__(self): return self.__class__.__name__ +class InstanceNorm(HybridBlock): + r""" + Applies instance normalization to the n-dimensional input array. + This operator takes an n-dimensional input array where (n>2) and normalizes + the input using the following formula: + + .. math:: + + out = \frac{x - mean[data]}{ \sqrt{Var[data]} + \epsilon} * gamma + beta + + Parameters + ---------- + epsilon: float, default 1e-5 + Small float added to variance to avoid dividing by zero. + center: bool, default True + If True, add offset of `beta` to normalized tensor. + If False, `beta` is ignored. + scale: bool, default True + If True, multiply by `gamma`. If False, `gamma` is not used. + When the next layer is linear (also e.g. `nn.relu`), + this can be disabled since the scaling + will be done by the next layer. + beta_initializer: str or `Initializer`, default 'zeros' + Initializer for the beta weight. + gamma_initializer: str or `Initializer`, default 'ones' + Initializer for the gamma weight. + in_channels : int, default 0 + Number of channels (feature maps) in input data. If not specified, + initialization will be deferred to the first time `forward` is called + and `in_channels` will be inferred from the shape of input data. + + Inputs: + - **data**: input tensor with arbitrary shape. + + Outputs: + - **out**: output tensor with the same shape as `data`. + + References + ---------- + `Instance Normalization: The Missing Ingredient for Fast Stylization + `_ + + Examples + -------- + >>> # Input of shape (2,1,2) + >>> x = mx.nd.array([[[ 1.1, 2.2]], + ... [[ 3.3, 4.4]]]) + >>> # Instance normalization is calculated with the above formula + >>> layer = InstanceNorm() + >>> layer.initialize(ctx=mx.cpu(0)) + >>> layer(x) + [[[-0.99998355 0.99998331]] + [[-0.99998319 0.99998361]]] + + """ + def __init__(self, epsilon=1e-5, center=True, scale=False, + beta_initializer='zeros', gamma_initializer='ones', + in_channels=0, **kwargs): + super(InstanceNorm, self).__init__(**kwargs) + self._kwargs = {'eps': epsilon} + self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null', + shape=(in_channels,), init=gamma_initializer, + allow_deferred_init=True) + self.beta = self.params.get('beta', grad_req='write' if center else 'null', + shape=(in_channels,), init=beta_initializer, + allow_deferred_init=True) + + def hybrid_forward(self, F, x, gamma, beta): + return F.InstanceNorm(x, gamma, beta, + name='fwd', **self._kwargs) + + def __repr__(self): + s = '{name}({content}' + in_channels = self.gamma.shape[0] + s += ', in_channels={0}'.format(in_channels) + s += ')' + return s.format(name=self.__class__.__name__, + content=', '.join(['='.join([k, v.__repr__()]) + for k, v in self._kwargs.items()])) + class Lambda(Block): r"""Wraps an operator or an expression as a Block object. @@ -526,7 +606,6 @@ def __repr__(self): class HybridLambda(HybridBlock): r"""Wraps an operator or an expression as a HybridBlock object. - Parameters ---------- function : str or function diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 822a81cc6b24..4e9ee689bdb1 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -23,7 +23,8 @@ 'MaxPool1D', 'MaxPool2D', 'MaxPool3D', 'AvgPool1D', 'AvgPool2D', 'AvgPool3D', 'GlobalMaxPool1D', 'GlobalMaxPool2D', 'GlobalMaxPool3D', - 'GlobalAvgPool1D', 'GlobalAvgPool2D', 'GlobalAvgPool3D'] + 'GlobalAvgPool1D', 'GlobalAvgPool2D', 'GlobalAvgPool3D', + 'ReflectionPad2D'] from ..block import HybridBlock from ... import symbol @@ -1007,3 +1008,34 @@ def __init__(self, layout='NCDHW', **kwargs): assert layout == 'NCDHW', "Only supports NCDHW layout for now" super(GlobalAvgPool3D, self).__init__( (1, 1, 1), None, 0, True, True, 'avg', **kwargs) + + +class ReflectionPad2D(HybridBlock): + """Pads the input tensor using the reflection of the input boundary. + + Parameters + ---------- + padding: int + An integer padding size + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` + - Output: :math:`(N, C, H_{out}, W_{out})` where + :math:`H_{out} = H_{in} + 2 * padding + :math:`W_{out} = W_{in} + 2 * padding + + Examples + -------- + >>> m = nn.ReflectionPad2D(3) + >>> input = mx.nd.random.normal(shape=(16, 3, 224, 224)) + >>> output = m(input) + """ + def __init__(self, padding=0, **kwargs): + super(ReflectionPad2D, self).__init__(**kwargs) + if isinstance(padding, numeric_types): + padding = (0, 0, 0, 0, padding, padding, padding, padding) + assert(len(padding) == 8) + self._padding = padding + + def hybrid_forward(self, F, x): + return F.pad(x, mode='reflect', pad_width=self._padding) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 80109cf99a69..d239705575bc 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -328,11 +328,22 @@ def test_pool(): layer.collect_params().initialize() assert (layer(x).shape==(2, 2, 4, 4)) + def test_batchnorm(): layer = nn.BatchNorm(in_channels=10) check_layer_forward(layer, (2, 10, 10, 10)) +def test_instancenorm(): + layer = nn.InstanceNorm(in_channels=10) + check_layer_forward(layer, (2, 10, 10, 10)) + + +def test_reflectionpad(): + layer = nn.ReflectionPad2D(3) + check_layer_forward(layer, (2, 3, 24, 24)) + + def test_reshape(): x = mx.nd.ones((2, 4, 10, 10)) layer = nn.Conv2D(10, 2, in_channels=4)