From eaad6ba41a2e70b5b63d63b72b873020a1250723 Mon Sep 17 00:00:00 2001
From: Skr Bang
Date: Wed, 15 Dec 2021 14:52:26 +0800
Subject: [PATCH] Add New API nn.HingeEmbeddingLoss (#37540)

* add hinge_embedding_loss
* fix test_API
* test_API succeed
* add English doc
* fixed using of expired fluid api
* fix doc
* fix doc and rm python/paddle/fluid/layers/loss.py
* get raw python/paddle/fluid/layers/loss.py back
* fix Examples bug in English doc
* unique -> flatten
* fix api code
* fix English doc
* fix functional loss English doc
* fix Example doc
* .numpy() -> paddle.unique()
* fix unique
* fix label_item_set
* modified judgment equation
* Got a beautiful loss equation
* use paddle.to_tensor
* fix loss and add static check
* fix loss and add static check
* delta -> margin
---
 .../unittests/test_hinge_embedding_loss.py    | 181 ++++++++++++++++++
 python/paddle/nn/__init__.py                  |   2 +
 python/paddle/nn/functional/__init__.py       |   2 +
 python/paddle/nn/functional/loss.py           | 107 ++++++++++-
 python/paddle/nn/layer/__init__.py            |   1 +
 python/paddle/nn/layer/loss.py                |  95 +++++++++
 6 files changed, 383 insertions(+), 5 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_hinge_embedding_loss.py

diff --git a/python/paddle/fluid/tests/unittests/test_hinge_embedding_loss.py b/python/paddle/fluid/tests/unittests/test_hinge_embedding_loss.py
new file mode 100644
index 0000000000000..91c1b45cbca41
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_hinge_embedding_loss.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import paddle
+import numpy as np
+import unittest
+from paddle.static import Program, program_guard
+
+np.random.seed(42)
+
+
+def calc_hinge_embedding_loss(input, label, margin=1.0, reduction='mean'):
+    # NumPy reference: input where label == 1, max(0, margin - input) where label == -1.
+    result = np.where(label == -1., np.maximum(0., margin - input), 0.) + \
+        np.where(label == 1., input, 0.)
+    if reduction == 'none':
+        return result
+    elif reduction == 'sum':
+        return np.sum(result)
+    elif reduction == 'mean':
+        return np.mean(result)
+
+
+class TestFunctionalHingeEmbeddingLoss(unittest.TestCase):
+    def setUp(self):
+        self.margin = 1.0
+        self.shape = (10, 10, 5)
+        self.input_np = np.random.random(size=self.shape).astype(np.float64)
+        # label elements are drawn from {1., -1.}
+        self.label_np = 2 * np.random.randint(0, 2, size=self.shape) - 1.
+
+    def run_dynamic_check(self, place=paddle.CPUPlace()):
+        paddle.disable_static(place=place)
+        input = paddle.to_tensor(self.input_np)
+        label = paddle.to_tensor(self.label_np, dtype=paddle.float64)
+
+        dy_result = paddle.nn.functional.hinge_embedding_loss(input, label)
+        expected = calc_hinge_embedding_loss(self.input_np, self.label_np)
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertEqual(dy_result.shape, [1])
+
+        dy_result = paddle.nn.functional.hinge_embedding_loss(
+            input, label, reduction='sum')
+        expected = calc_hinge_embedding_loss(
+            self.input_np, self.label_np, reduction='sum')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertEqual(dy_result.shape, [1])
+
+        dy_result = paddle.nn.functional.hinge_embedding_loss(
+            input, label, reduction='none')
+        expected = calc_hinge_embedding_loss(
+            self.input_np, self.label_np, reduction='none')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertEqual(dy_result.shape, list(self.shape))
+
+    def run_static_check(self, place=paddle.CPUPlace()):
+        paddle.enable_static()
+        for reduction in ['none', 'mean', 'sum']:
+            expected = calc_hinge_embedding_loss(
+                self.input_np, self.label_np, reduction=reduction)
+            with program_guard(Program(), Program()):
+                input = paddle.static.data(
+                    name="input", shape=self.shape, dtype=paddle.float64)
+                label = paddle.static.data(
+                    name="label", shape=self.shape, dtype=paddle.float64)
+                st_result = paddle.nn.functional.hinge_embedding_loss(
+                    input, label, reduction=reduction)
+                exe = paddle.static.Executor(place)
+                result_numpy, = exe.run(
+                    feed={"input": self.input_np,
+                          "label": self.label_np},
+                    fetch_list=[st_result])
+                self.assertTrue(np.allclose(result_numpy, expected))
+
+    def test_cpu(self):
+        self.run_dynamic_check(place=paddle.CPUPlace())
+        self.run_static_check(place=paddle.CPUPlace())
+
+    def test_gpu(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        self.run_dynamic_check(place=paddle.CUDAPlace(0))
+        self.run_static_check(place=paddle.CUDAPlace(0))
+
+    # test the error raised for an invalid reduction
+    def test_reduce_errors(self):
+        def test_value_error():
+            loss = paddle.nn.functional.hinge_embedding_loss(
+                self.input_np, self.label_np, reduction='reduce_mean')
+
+        self.assertRaises(ValueError, test_value_error)
+
+
+class TestClassHingeEmbeddingLoss(unittest.TestCase):
+    def setUp(self):
+        self.margin = 1.0
+        self.shape = (10, 10, 5)
+        self.input_np = np.random.random(size=self.shape).astype(np.float64)
+        # label elements are drawn from {1., -1.}
+        self.label_np = 2 * np.random.randint(0, 2, size=self.shape) - 1.
+
+    def run_dynamic_check(self, place=paddle.CPUPlace()):
+        paddle.disable_static(place=place)
+        input = paddle.to_tensor(self.input_np)
+        label = paddle.to_tensor(self.label_np, dtype=paddle.float64)
+        hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss()
+        dy_result = hinge_embedding_loss(input, label)
+        expected = calc_hinge_embedding_loss(self.input_np, self.label_np)
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertEqual(dy_result.shape, [1])
+
+        hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
+            reduction='sum')
+        dy_result = hinge_embedding_loss(input, label)
+        expected = calc_hinge_embedding_loss(
+            self.input_np, self.label_np, reduction='sum')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertEqual(dy_result.shape, [1])
+
+        hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
+            reduction='none')
+        dy_result = hinge_embedding_loss(input, label)
+        expected = calc_hinge_embedding_loss(
+            self.input_np, self.label_np, reduction='none')
+        self.assertTrue(np.allclose(dy_result.numpy(), expected))
+        self.assertEqual(dy_result.shape, list(self.shape))
+
+    def run_static_check(self, place=paddle.CPUPlace()):
+        paddle.enable_static()
+        for reduction in ['none', 'mean', 'sum']:
+            expected = calc_hinge_embedding_loss(
+                self.input_np, self.label_np, reduction=reduction)
+            with program_guard(Program(), Program()):
+                input = paddle.static.data(
+                    name="input", shape=self.shape, dtype=paddle.float64)
+                label = paddle.static.data(
+                    name="label", shape=self.shape, dtype=paddle.float64)
+                hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
+                    reduction=reduction)
+                st_result = hinge_embedding_loss(input, label)
+                exe = paddle.static.Executor(place)
+                result_numpy, = exe.run(
+                    feed={"input": self.input_np,
+                          "label": self.label_np},
+                    fetch_list=[st_result])
+                self.assertTrue(np.allclose(result_numpy, expected))
+
+    def test_cpu(self):
+        self.run_dynamic_check(place=paddle.CPUPlace())
+        self.run_static_check(place=paddle.CPUPlace())
+
+    def test_gpu(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        self.run_dynamic_check(place=paddle.CUDAPlace(0))
+        self.run_static_check(place=paddle.CUDAPlace(0))
+
+    # test the error raised for an invalid reduction
+    def test_reduce_errors(self):
+        def test_value_error():
+            hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
+                reduction='reduce_mean')
+            loss = hinge_embedding_loss(self.input_np, self.label_np)
+
+        self.assertRaises(ValueError, test_value_error)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 3afd2b56569a4..e1c40e8d0d3d7 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -101,6 +101,7 @@
 from .layer.loss import MarginRankingLoss  # noqa: F401
 from .layer.loss import CTCLoss  # noqa: F401
 from .layer.loss import SmoothL1Loss  # noqa: F401
+from .layer.loss import HingeEmbeddingLoss  # noqa: F401
 from .layer.norm import BatchNorm  # noqa: F401
 from .layer.norm import SyncBatchNorm  # noqa: F401
 from .layer.norm import GroupNorm  # noqa: F401
@@ -297,4 +298,5 @@ def weight_norm(*args):
     'LayerDict',
     'ZeroPad2D',
     'MaxUnPool2D',
+    'HingeEmbeddingLoss',
 ]
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 3e3bd6397c072..a504c1ee6a4fe 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -86,6 +86,7 @@
 from .loss import margin_cross_entropy  # noqa: F401
 from .loss import square_error_cost  # noqa: F401
 from .loss import ctc_loss  # noqa: F401
+from .loss import hinge_embedding_loss  # noqa: F401
 from .norm import batch_norm  # noqa: F401
 from .norm import instance_norm  # noqa: F401
 from .norm import layer_norm  # noqa: F401
@@ -200,6 +201,7 @@
     'margin_cross_entropy',
     'square_error_cost',
     'ctc_loss',
+    'hinge_embedding_loss',
     'affine_grid',
     'grid_sample',
     'local_response_norm',
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 2332c14b2d97a..328eb07b5e960 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1720,8 +1720,8 @@ def cross_entropy(input,
                 raise ValueError(
                     "input's class_dimension({}) must equal to "
                     "weight's class_dimension({}) "
-                    "when weight is provided"\
-                        .format(input.shape[axis], weight.shape[-1]))
+                    "when weight is provided" \
+                    .format(input.shape[axis], weight.shape[-1]))
 
             ignore_weight_mask = paddle.cast((label != ignore_index),
                                              out.dtype)
@@ -1732,7 +1732,7 @@
                                                         axis)
             if axis != -1 and axis != valid_label.ndim - 1:
                 temp_perm = list(range(axis % valid_label.ndim)) \
-                            + list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \
+                            + list(range((axis % valid_label.ndim + 1), valid_label.ndim)) \
                             + [axis % valid_label.ndim]
                 weight_gather = _C_ops.gather_nd(
                     weight, valid_label.transpose(temp_perm))
@@ -1834,8 +1834,8 @@
     else:
         if input.shape[axis] != weight.shape[-1]:
             raise ValueError("input's class_dimension({}) must equal to "
-                             "weight's class_dimension({}) "
-                             "when weight is provided"\
+                             "weight's class_dimension({}) "
+                             "when weight is provided" \
                              .format(input.shape[axis], weight.shape[-1]))
 
         valid_label = paddle.where(label == ignore_index,
@@ -2051,3 +2051,100 @@ def sigmoid_focal_loss(logit,
         loss = paddle.sum(loss, name=name)
 
     return loss
+
+
+def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
+    r"""
+    This operator calculates the hinge embedding loss. It measures the loss given an input tensor :math:`x` and a label tensor :math:`y` (containing 1 or -1).
+    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
+    and is typically used for learning nonlinear embeddings or semi-supervised learning.
+
+    The loss function for the :math:`n`-th sample in the mini-batch is
+
+    .. math::
+        l_n = \begin{cases}
+            x_n, & \text{if}\; y_n = 1,\\
+            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
+        \end{cases}
+
+    and the total loss function is
+
+    .. math::
+        \ell(x, y) = \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    where :math:`L = \{l_1,\dots,l_N\}^\top` and :math:`\Delta` is the :attr:`margin`.
+
+    Parameters:
+        input (Tensor): Input tensor with shape [N, \*], where N is the batch size
+            and `\*` means any number of additional dimensions. The data type is float32 or float64.
+        label (Tensor): Label tensor containing 1 or -1, with the same shape as input.
+            The data type is float32 or float64.
+        margin (float, optional): The margin :math:`\Delta` to be used.
+            It determines how large an input with label -1 needs to be before
+            it stops contributing to the loss: inputs smaller than the margin
+            are penalized by :math:`\max\{0, \Delta - x_n\}`. Default: 1.0
+        reduction (str, optional): Indicates how to reduce the loss over the batch;
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default: ``'mean'``
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+
+        input: N-D Tensor, the shape is [N, \*], where N is the batch size and `\*` means any number of additional dimensions; available dtype is float32, float64. The sum operation operates over all the elements.
+
+        label: N-D Tensor, same shape as the input. Tensor elements should contain 1 or -1; the data type is float32 or float64.
+
+        output: scalar. If :attr:`reduction` is ``'none'``, it has the same shape as the input.
+
+    Returns:
+        Tensor. The tensor storing the hinge_embedding_loss of input and label.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+
+            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
+            # label elements in {1., -1.}
+            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
+
+            loss = F.hinge_embedding_loss(input, label, margin=1.0, reduction='none')
+            print(loss)
+            # Tensor([[0., -2., 0.],
+            #         [0., -1., 2.],
+            #         [1., 1., 1.]])
+
+            loss = F.hinge_embedding_loss(input, label, margin=1.0, reduction='mean')
+            print(loss)
+            # Tensor([0.22222222])
+    """
+
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "'reduction' in 'hinge_embedding_loss' should be 'sum', 'mean' or 'none', "
+            "but received {}.".format(reduction))
+
+    if not paddle.fluid.framework.in_dygraph_mode():
+        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                                 'hinge_embedding_loss')
+        check_variable_and_dtype(label, 'label', ['float32', 'float64'],
+                                 'hinge_embedding_loss')
+
+    zero_ = paddle.zeros([1], dtype=input.dtype)
+    loss = paddle.where(label == 1., input, zero_) + \
+        paddle.where(label == -1., paddle.nn.functional.relu(margin - input), zero_)
+
+    if reduction == 'mean':
+        return paddle.mean(loss, name=name)
+    elif reduction == 'sum':
+        return paddle.sum(loss, name=name)
+    elif reduction == 'none':
+        return loss
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index a65f9912d5939..a78269a4cd4d7 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -73,6 +73,7 @@
 from .loss import MarginRankingLoss  # noqa: F401
 from .loss import CTCLoss  # noqa: F401
 from .loss import SmoothL1Loss  # noqa: F401
+from .loss import HingeEmbeddingLoss  # noqa: F401
 from .norm import BatchNorm1D  # noqa: F401
 from .norm import BatchNorm2D  # noqa: F401
 from .norm import BatchNorm3D  # noqa: F401
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 3ac0d675fb72c..9da41f26969c8 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1203,3 +1203,98 @@ def forward(self, input, label):
             reduction=self.reduction,
             delta=self.delta,
             name=self.name)
+
+
+class HingeEmbeddingLoss(Layer):
+    r"""
+    This operator calculates the hinge embedding loss. It measures the loss given an input tensor :math:`x` and a label tensor :math:`y` (containing 1 or -1).
+    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
+    and is typically used for learning nonlinear embeddings or semi-supervised learning.
+
+    The loss function for the :math:`n`-th sample in the mini-batch is
+
+    .. math::
+        l_n = \begin{cases}
+            x_n, & \text{if}\; y_n = 1,\\
+            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
+        \end{cases}
+
+    and the total loss function is
+
+    .. math::
+        \ell(x, y) = \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    where :math:`L = \{l_1,\dots,l_N\}^\top` and :math:`\Delta` is the :attr:`margin`.
+
+    Parameters:
+
+        margin (float, optional): The margin :math:`\Delta` to be used.
+            It determines how large an input with label -1 needs to be before
+            it stops contributing to the loss: inputs smaller than the margin
+            are penalized by :math:`\max\{0, \Delta - x_n\}`. Default: 1.0
+        reduction (str, optional): Indicates how to reduce the loss over the batch;
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default: ``'mean'``
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Call Parameters:
+
+        input (Tensor): Input tensor, the data type is float32 or float64. The shape is [N, \*], where N is the batch size and `\*` means any number of additional dimensions.
+
+        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.
+
+    Shape:
+
+        input: N-D Tensor, the shape is [N, \*], where N is the batch size and `\*` means any number of additional dimensions; available dtype is float32, float64. The sum operation operates over all the elements.
+
+        label: N-D Tensor, same shape as the input.
+
+        output: scalar. If :attr:`reduction` is ``'none'``, it has the same shape as the input.
+
+    Returns:
+
+        Tensor. The tensor storing the hinge_embedding_loss of input and label.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn as nn
+
+            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
+            # label elements in {1., -1.}
+            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
+
+            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='none')
+            loss = hinge_embedding_loss(input, label)
+            print(loss)
+            # Tensor([[0., -2., 0.],
+            #         [0., -1., 2.],
+            #         [1., 1., 1.]])
+
+            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
+            loss = hinge_embedding_loss(input, label)
+            print(loss)
+            # Tensor([0.22222222])
+    """
+
+    def __init__(self, margin=1.0, reduction="mean", name=None):
+        super(HingeEmbeddingLoss, self).__init__()
+        self.margin = margin
+        self.reduction = reduction
+        self.name = name
+
+    def forward(self, input, label):
+        return F.hinge_embedding_loss(
+            input,
+            label,
+            reduction=self.reduction,
+            margin=self.margin,
+            name=self.name)
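
A minimal end-to-end sketch for trying the new API, assuming a Paddle build that includes this patch; it cross-checks the functional form against the same NumPy reference formula used in the unit tests above. The helper name `numpy_hinge_embedding_loss` is illustrative only and is not part of the patch.

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    def numpy_hinge_embedding_loss(x, y, margin=1.0):
        # Reference formula: x_n where y_n == 1, max(0, margin - x_n) where y_n == -1.
        return np.where(y == 1., x, 0.) + \
            np.where(y == -1., np.maximum(0., margin - x), 0.)

    x_np = np.random.random(size=(4, 3)).astype(np.float64)
    y_np = (2 * np.random.randint(0, 2, size=(4, 3)) - 1).astype(np.float64)

    # Functional form (assumes this patch is applied).
    loss = F.hinge_embedding_loss(
        paddle.to_tensor(x_np), paddle.to_tensor(y_np),
        margin=1.0, reduction='mean')
    assert np.allclose(loss.numpy(), numpy_hinge_embedding_loss(x_np, y_np).mean())

    # Layer form, equivalent by construction: forward delegates to the functional API.
    layer_loss = paddle.nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')(
        paddle.to_tensor(x_np), paddle.to_tensor(y_np))
    assert np.allclose(layer_loss.numpy(), loss.numpy())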