diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 816bb263cec2c..d18c8e2597444 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -635,5 +635,103 @@ def test_dygraph(self): self.assertTrue(np.allclose(result.numpy(), result_np)) +class TestAlphaDropoutFAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[40, 40], dtype="float32") + res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.) + res2 = paddle.nn.functional.alpha_dropout( + x=input, p=0., training=False) + + in_np = np.random.random([40, 40]).astype("float32") + res_np = in_np + + exe = fluid.Executor(place) + res_list = [res1, res2] + for res in res_list: + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res]) + self.assertTrue(np.allclose(fetches[0], res_np)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + in_np = np.random.random([40, 40]).astype("float32") + res_np = in_np + input = fluid.dygraph.to_variable(in_np) + + res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.) + res2 = paddle.nn.functional.alpha_dropout( + x=input, p=0., training=False) + + res_list = [res1, res2] + for res in res_list: + self.assertTrue(np.allclose(res.numpy(), res_np)) + + +class TestAlphaDropoutFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of dropout must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.nn.functional.alpha_dropout(x1, p=0.5) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + # the input dtype of dropout must be float32 or float64 + xr = fluid.data(name='xr', shape=[3, 4, 5, 6], dtype="int32") + paddle.nn.functional.alpha_dropout(xr) + + self.assertRaises(TypeError, test_dtype) + + def test_pdtype(): + # p should be int or float + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.alpha_dropout(x2, p='0.5') + + self.assertRaises(TypeError, test_pdtype) + + def test_pvalue(): + # p should be 0.<=p<=1. + x2 = fluid.data(name='x2', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.alpha_dropout(x2, p=1.2) + + self.assertRaises(ValueError, test_pvalue) + + +class TestAlphaDropoutCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([40, 40]).astype("float32") + result_np = input_np + input = fluid.dygraph.to_variable(input_np) + m = paddle.nn.AlphaDropout(p=0.) + m.eval() + result = m(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index ac04dcaa5a070..131231ade67a7 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -91,6 +91,7 @@ from .layer.common import Dropout #DEFINE_ALIAS from .layer.common import Dropout2D #DEFINE_ALIAS from .layer.common import Dropout3D #DEFINE_ALIAS +from .layer.common import AlphaDropout #DEFINE_ALIAS from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .layer.conv import Conv1d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 4d72a373b7ba5..c1fcc230c1c8d 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -57,6 +57,7 @@ from .common import dropout #DEFINE_ALIAS from .common import dropout2d #DEFINE_ALIAS from .common import dropout3d #DEFINE_ALIAS +from .common import alpha_dropout #DEFINE_ALIAS # from .common import embedding #DEFINE_ALIAS # from .common import fc #DEFINE_ALIAS from .common import label_smooth #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 8091549c08b6d..bf404d54b7d15 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -40,6 +40,7 @@ 'dropout', 'dropout2d', 'dropout3d', + 'alpha_dropout', # 'embedding', # 'fc', 'label_smooth', @@ -476,7 +477,6 @@ def dropout(x, p (float | int): Probability of setting units to zero. Default 0.5. axis (int | list): The axis along which the dropout is performed. Default None. training (bool): A flag indicating whether it is in train phrase or not. Default True. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'] 1. upscale_in_train(default), upscale the output at training time @@ -488,6 +488,7 @@ def dropout(x, - train: out = input * mask - inference: out = input * (1.0 - dropout_prob) + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: A Tensor representing the dropout, has same shape and data type as `x` . @@ -549,7 +550,7 @@ def dropout(x, [4 0 6]] (3) What about ``axis=[0, 1]`` ? This means the dropout is performed in all axes of x, which is the same case as default setting ``axis=None`` . - (4) You may note that logically `axis=None` means the dropout is performed in no axis of x, + (4) You may note that logically `axis=None` means the dropout is performed in none axis of x, We generate mask with the shape 1*1. Whole input is randomly selected or dropped. For example, we may get such mask: [[0]] @@ -563,8 +564,7 @@ def dropout(x, When x is a 4d tensor with shape `NCHW`, we can set ``axis=[0,1]`` and the dropout will be performed in channel `N` and `C`, `H` and `W` is tied, i.e. paddle.nn.dropout(x, p, axis=[0,1]) - This is something we called dropout2d. Please refer to ``paddle.nn.functional.dropout2d`` - for more details. + Please refer to ``paddle.nn.functional.dropout2d`` for more details. Similarly, when x is a 5d tensor with shape `NCDHW`, we can set ``axis=[0,1]`` to perform dropout3d. Please refer to ``paddle.nn.functional.dropout3d`` for more details. @@ -795,6 +795,80 @@ def dropout3d(x, p=0.5, training=True, data_format='NCDHW', name=None): name=name) +def alpha_dropout(x, p=0.5, training=True, name=None): + """ + Alpha Dropout is a type of Dropout that maintains the self-normalizing property. + For an input with zero mean and unit standard deviation, the output of Alpha Dropout + maintains the original mean and standard deviation of the input. + Alpha Dropout fits well to SELU activate function by randomly setting activations to the negative saturation value. + + Args: + x (Tensor): The input tensor. The data type is float32 or float64. + p (float | int): Probability of setting units to zero. Default 0.5. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor representing the dropout, has same shape and data type as `x`. + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.array([[-1, 1], [-1, 1]]).astype('float32') + x = paddle.to_tensor(x) + y_train = paddle.nn.functional.alpha_dropout(x, 0.5) + y_test = paddle.nn.functional.alpha_dropout(x, 0.5, training=False) + print(x.numpy()) + print(y_train.numpy()) + # [[-0.10721093, 1.6655989 ], [-0.7791938, -0.7791938]] (randomly) + print(y_test.numpy()) + """ + if not isinstance(p, (float, int)): + raise TypeError("p argument should be a float or int") + if p < 0 or p > 1: + raise ValueError("p argument should between 0 and 1") + + if not in_dygraph_mode(): + check_variable_and_dtype(x, 'x', ['float32', 'float64'], + 'alpha_dropout') + + if training: + #get transformation params + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + alpha_p = -alpha * scale + a = ((1 - p) * (1 + p * alpha_p**2))**-0.5 + b = -a * alpha_p * p + + dtype = x.dtype + input_shape = x.shape + + #get mask + random_tensor = layers.uniform_random( + input_shape, dtype='float32', min=0., max=1.0) + p = layers.fill_constant(shape=[1], dtype='float32', value=p) + keep_mask = layers.greater_equal(random_tensor, p) + keep_mask = layers.cast(keep_mask, dtype) + drop_mask = layers.elementwise_sub( + layers.fill_constant( + shape=input_shape, dtype=dtype, value=1.), + keep_mask) + + #apply mask + b = layers.fill_constant(shape=[1], dtype=dtype, value=b) + y = layers.elementwise_add( + paddle.multiply(x, keep_mask), + layers.scale( + drop_mask, scale=alpha_p)) + res = layers.elementwise_add(layers.scale(y, scale=a), b, name=name) + return res + else: # test + return x + + def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): """ Pad tensor according to 'pad' and 'mode'. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 859b2bf296ebb..70c1f754d91ec 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -55,6 +55,7 @@ from .common import Dropout #DEFINE_ALIAS from .common import Dropout2D #DEFINE_ALIAS from .common import Dropout3D #DEFINE_ALIAS +from .common import AlphaDropout #DEFINE_ALIAS from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .conv import Conv1d #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 3034880533e4a..e0d751eef42a9 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -23,25 +23,11 @@ from ...fluid.framework import _dygraph_tracer __all__ = [ - 'BilinearTensorProduct', - 'Pool2D', - 'Embedding', - 'Linear', - 'UpSample', - 'Pad2D', - 'ReflectionPad1d', - 'ReplicationPad1d', - 'ConstantPad1d', - 'ReflectionPad2d', - 'ReplicationPad2d', - 'ConstantPad2d', - 'ZeroPad2d', - 'ConstantPad3d', - 'ReplicationPad3d', - 'CosineSimilarity', - 'Dropout', - 'Dropout2D', - 'Dropout3D', + 'BilinearTensorProduct', 'Pool2D', 'Embedding', 'Linear', 'UpSample', + 'Pad2D', 'ReflectionPad1d', 'ReplicationPad1d', 'ConstantPad1d', + 'ReflectionPad2d', 'ReplicationPad2d', 'ConstantPad2d', 'ZeroPad2d', + 'ConstantPad3d', 'ReplicationPad3d', 'CosineSimilarity', 'Dropout', + 'Dropout2D', 'Dropout3D', 'AlphaDropout' ] @@ -361,12 +347,12 @@ class Dropout(layers.Layer): according to the given dropout probability. See ``paddle.nn.functional.dropout`` for more details. - In dygraph mode, please use ``eval()`` to indicate whether it is in test phrase or not. + + In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled. Parameters: p (float | int): Probability of setting units to zero. Default: 0.5 axis (int | list): The axis along which the dropout is performed. Default None. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] 1. upscale_in_train(default), upscale the output at training time @@ -378,6 +364,7 @@ class Dropout(layers.Layer): - train: out = input * mask - inference: out = input * (1.0 - p) + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Shape: - input: N-D tensor. @@ -404,7 +391,6 @@ def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None): super(Dropout, self).__init__() self.p = p - self.training = _dygraph_tracer()._train_mode self.axis = axis self.mode = mode self.name = name @@ -430,7 +416,8 @@ class Dropout2D(layers.Layer): See ``paddle.nn.functional.dropout2d`` for more details. - Please use ``eval()`` to indicate whether it is in test phrase or not. + In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled. + Parameters: p (float, optional): Probability of setting units to zero. Default: 0.5 data_format (str, optional): Specify the data format of the input, and the data format of the output @@ -487,7 +474,8 @@ class Dropout3D(layers.Layer): See ``paddle.nn.functional.dropout3d`` for more details. - Please use ``eval()`` to indicate whether it is in test phrase or not. + In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled. + Parameters: p (float | int): Probability of setting units to zero. Default: 0.5 data_format (str, optional): Specify the data format of the input, and the data format of the output @@ -521,7 +509,6 @@ def __init__(self, p=0.5, data_format='NCDHW', name=None): super(Dropout3D, self).__init__() self.p = p - self.training = _dygraph_tracer()._train_mode self.data_format = data_format self.name = name @@ -535,6 +522,55 @@ def forward(self, input): return out +class AlphaDropout(layers.Layer): + """ + Alpha Dropout is a type of Dropout that maintains the self-normalizing property. For an input with + zero mean and unit standard deviation, the output of Alpha Dropout maintains the original mean and + standard deviation of the input. Alpha Dropout fits well to SELU activate function by randomly setting + activations to the negative saturation value. + + For more information, please refer to: + `Self-Normalizing Neural Networks `_ + + In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled. + + Parameters: + p (float | int): Probability of setting units to zero. Default: 0.5 + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: N-D tensor. + - output: N-D tensor, the same shape as input. + + Examples: + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + x = np.array([[-1, 1], [-1, 1]]).astype('float32') + x = paddle.to_tensor(x) + m = paddle.nn.AlphaDropout(p=0.5) + y_train = m(x) + m.eval() # switch the model to test phase + y_test = m(x) + print(x.numpy()) + print(y_train.numpy()) + # [[-0.10721093, 1.6655989 ], [-0.7791938, -0.7791938]] (randomly) + print(y_test.numpy()) + """ + + def __init__(self, p=0.5, name=None): + super(AlphaDropout, self).__init__() + self.p = p + self.name = name + + def forward(self, input): + out = F.alpha_dropout( + input, p=self.p, training=self.training, name=self.name) + return out + + class ReflectionPad1d(layers.Layer): """ This interface is used to construct a callable object of the ``ReflectionPad1d`` class.