Add New API nn.HingeEmbeddingLoss #37540

Merged · 36 commits · Dec 15, 2021

Changes from 29 commits

Commits:
5e643ca
add hinge_embedding_loss
S-HuaBomb Nov 24, 2021
6edb279
fix test_API
S-HuaBomb Nov 25, 2021
7e9207e
test_API succeed
S-HuaBomb Nov 25, 2021
89da508
add English doc
S-HuaBomb Nov 25, 2021
769a4b8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Nov 25, 2021
5ea300d
fixed using of expired fluid api
S-HuaBomb Nov 25, 2021
98d687e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Nov 25, 2021
c6bd8d4
fix doc
S-HuaBomb Nov 25, 2021
341ba5b
fix doc and rm python/paddle/fluid/layers/loss.py
S-HuaBomb Nov 25, 2021
a20b2de
get raw python/paddle/fluid/layers/loss.py back
S-HuaBomb Nov 25, 2021
b83a1e2
fix Examples bug in English doc
S-HuaBomb Nov 25, 2021
e077423
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Nov 25, 2021
2988c76
unique -> flatten
S-HuaBomb Nov 25, 2021
04cf985
fix api code
S-HuaBomb Nov 26, 2021
a3bfd3e
fix English doc
S-HuaBomb Nov 26, 2021
d7d892c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Nov 26, 2021
562aa9c
fix functional loss English doc
S-HuaBomb Nov 26, 2021
c658354
fix Example doc
S-HuaBomb Nov 26, 2021
b47d411
.numpy() -> paddle.unique()
S-HuaBomb Dec 2, 2021
de677a6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 2, 2021
ebbe89e
fix unique
S-HuaBomb Dec 2, 2021
3231d91
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 2, 2021
aa9f9c6
fix label_item_set
S-HuaBomb Dec 5, 2021
22a25dc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 5, 2021
c0b31b3
modified judgment equation
S-HuaBomb Dec 5, 2021
f7c49b7
Got a beautiful loss equation
S-HuaBomb Dec 6, 2021
aca041e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 6, 2021
3e94ec4
use paddle.to_tensor
S-HuaBomb Dec 7, 2021
319d085
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 7, 2021
b613766
fix loss and add static check
S-HuaBomb Dec 7, 2021
0e323aa
fix loss and add static check
S-HuaBomb Dec 7, 2021
146f809
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 7, 2021
9eefb38
fix conflicts
S-HuaBomb Dec 7, 2021
a6fcc5c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 9, 2021
f92f098
delta -> margin
S-HuaBomb Dec 14, 2021
0391ff7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
S-HuaBomb Dec 14, 2021
141 changes: 141 additions & 0 deletions python/paddle/fluid/tests/unittests/test_hinge_embedding_loss.py
@@ -0,0 +1,141 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import numpy as np
import unittest

np.random.seed(42)


class TestFunctionalHingeEmbeddingLoss(unittest.TestCase):
def setUp(self):
self.delta = 1.0
self.shape = (10, 10, 5)
self.input_np = np.random.random(size=self.shape).astype(np.float32)
# get label elem in {1., -1.}
self.label_np = 2 * np.random.randint(0, 2, size=self.shape) - 1.
# get wrong label elem not in {1., -1.}
self.wrong_label = paddle.randint(-3, 3, shape=self.shape)

def run_dynamic_check(self):
Contributor:

Static graph mode is now supported, so please add a static graph test as well.

Contributor Author:

Got it. Done, thanks.

input = paddle.to_tensor(self.input_np)
label = paddle.to_tensor(self.label_np, dtype=paddle.float32)
dy_result = paddle.nn.functional.hinge_embedding_loss(input, label)
expected = np.mean(
np.where(label.numpy() == -1.,
np.maximum(0., self.delta - input.numpy()), input.numpy()))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
self.assertEqual(dy_result.shape, [1])

dy_result = paddle.nn.functional.hinge_embedding_loss(
input, label, reduction='sum')
expected = np.sum(
np.where(label.numpy() == -1.,
np.maximum(0., self.delta - input.numpy()), input.numpy()))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
self.assertEqual(dy_result.shape, [1])

dy_result = paddle.nn.functional.hinge_embedding_loss(
input, label, reduction='none')
expected = np.where(label.numpy() == -1.,
np.maximum(0., self.delta - input.numpy()),
input.numpy())
self.assertTrue(np.allclose(dy_result.numpy(), expected))
self.assertEqual(dy_result.shape, list(self.shape))

def test_cpu(self):
paddle.disable_static(place=paddle.CPUPlace())
self.run_dynamic_check()

def test_gpu(self):
if not paddle.is_compiled_with_cuda():
return

paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_dynamic_check()

# test the error message raised for an invalid reduction
def test_reduce_errors(self):
def test_value_error():
loss = paddle.nn.functional.hinge_embedding_loss(
self.input_np, self.label_np, reduction='reduce_mean')

self.assertRaises(ValueError, test_value_error)


class TestClassHingeEmbeddingLoss(unittest.TestCase):
def setUp(self):
self.delta = 1.0
self.shape = (10, 10, 5)
self.input_np = np.random.random(size=self.shape).astype(np.float32)
# get label elem in {1., -1.}
self.label_np = 2 * np.random.randint(0, 2, size=self.shape) - 1.
# get wrong label elem not in {1., -1.}
self.wrong_label = paddle.randint(-3, 3, shape=self.shape)

def run_dynamic_check(self):
input = paddle.to_tensor(self.input_np)
label = paddle.to_tensor(self.label_np, dtype=paddle.float32)
hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss()
dy_result = hinge_embedding_loss(input, label)
expected = np.mean(
np.where(label.numpy() == -1.,
np.maximum(0., self.delta - input.numpy()), input.numpy()))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
self.assertEqual(dy_result.shape, [1])

hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
reduction='sum')
dy_result = hinge_embedding_loss(input, label)
expected = np.sum(
np.where(label.numpy() == -1.,
np.maximum(0., self.delta - input.numpy()), input.numpy()))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
self.assertEqual(dy_result.shape, [1])

hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
reduction='none')
dy_result = hinge_embedding_loss(input, label)
expected = np.where(label.numpy() == -1.,
np.maximum(0., self.delta - input.numpy()),
input.numpy())
self.assertTrue(np.allclose(dy_result.numpy(), expected))
self.assertEqual(dy_result.shape, list(self.shape))

def test_cpu(self):
paddle.disable_static(place=paddle.CPUPlace())
self.run_dynamic_check()

def test_gpu(self):
if not paddle.is_compiled_with_cuda():
return

paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_dynamic_check()

# test the error message raised for an invalid reduction
def test_reduce_errors(self):
def test_value_error():
hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
reduction='reduce_mean')
loss = hinge_embedding_loss(self.input_np, self.label_np)

self.assertRaises(ValueError, test_value_error)


if __name__ == "__main__":
unittest.main()
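A reviewer in this file asked for a static-graph test in addition to the dynamic checks above. The snapshot shown here ("Changes from 29 commits") predates that addition, which lands in the later "fix loss and add static check" commits. A minimal sketch of such a check, assuming the functional API in this diff also runs under static graph, might look like this:

import numpy as np
import paddle
import paddle.static as static

# Illustrative static-graph check; names and structure here are assumptions,
# not the code that was actually merged.
paddle.enable_static()

shape = (10, 10, 5)
input_np = np.random.random(size=shape).astype(np.float32)
label_np = (2 * np.random.randint(0, 2, size=shape) - 1).astype(np.float32)

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name='input', shape=shape, dtype='float32')
    y = static.data(name='label', shape=shape, dtype='float32')
    loss = paddle.nn.functional.hinge_embedding_loss(x, y, reduction='mean')

exe = static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
result, = exe.run(main_prog,
                  feed={'input': input_np, 'label': label_np},
                  fetch_list=[loss])

# Compare against the same NumPy reference the dynamic tests use.
expected = np.mean(np.where(label_np == -1.,
                            np.maximum(0., 1.0 - input_np), input_np))
assert np.allclose(result, expected)
paddle.disable_static()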
4 changes: 3 additions & 1 deletion python/paddle/nn/__init__.py
@@ -101,6 +101,7 @@
from .layer.loss import MarginRankingLoss # noqa: F401
from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
@@ -295,5 +296,6 @@ def weight_norm(*args):
'ELU',
'ReLU6',
'LayerDict',
'ZeroPad2D'
'ZeroPad2D',
'HingeEmbeddingLoss'
]
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -86,6 +86,7 @@
from .loss import margin_cross_entropy # noqa: F401
from .loss import square_error_cost # noqa: F401
from .loss import ctc_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
@@ -200,6 +201,7 @@
'margin_cross_entropy',
'square_error_cost',
'ctc_loss',
'hinge_embedding_loss',
'affine_grid',
'grid_sample',
'local_response_norm',
106 changes: 101 additions & 5 deletions python/paddle/nn/functional/loss.py
@@ -1720,8 +1720,8 @@ def cross_entropy(input,
raise ValueError(
"input's class_dimension({}) must equal to "
"weight's class_dimension({}) "
"when weight is provided"\
.format(input.shape[axis], weight.shape[-1]))
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))

ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype)
@@ -1732,7 +1732,7 @@
axis)
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1), valid_label.ndim)) \
+ [axis % valid_label.ndim]
weight_gather = _C_ops.gather_nd(
weight, valid_label.transpose(temp_perm))
@@ -1834,8 +1834,8 @@ def cross_entropy(input,
else:
if input.shape[axis] != weight.shape[-1]:
raise ValueError("input's class_dimension({}) must equal to "
"weight's class_dimension({}) "
"when weight is provided"\
"weight's class_dimension({}) "
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))

valid_label = paddle.where(label == ignore_index,
@@ -2051,3 +2051,99 @@ def sigmoid_focal_loss(logit,
loss = paddle.sum(loss, name=name)

return loss


def hinge_embedding_loss(input, label, delta=1.0, reduction='mean', name=None):
r"""
This operator computes hinge_embedding_loss. It measures the loss given an input tensor :math:`x` and a label tensor :math:`y` (containing 1 or -1).
This is usually used to measure whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
and is typically applied to learning nonlinear embeddings or semi-supervised learning.

The loss function for the :math:`n`-th sample in the mini-batch is

.. math::
l_n = \begin{cases}
x_n, & \text{if}\; y_n = 1,\\
\max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
\end{cases}

and the total loss function is

.. math::
\ell(x, y) = \begin{cases}
\operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
\end{cases}

where :math:`L = \{l_1,\dots,l_N\}^\top`.

Parameters:
input (Tensor): Input tensor, the data type is float32 or float64.
The shape is [N, \*], where N is the batch size and `\*` means any number of additional dimensions.
label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64.
The shape of label is the same as the shape of input.
delta (float, optional): Specifies the hyperparameter delta to be used.
Contributor:

Just to confirm: what is the reason for naming this parameter delta rather than margin? The margin_ranking_loss API uses margin.

Contributor Author:

I named it delta because the formula uses Δ. PyTorch's equivalent API also uses margin; should I rename it to margin?

Contributor:

Please change it to margin, to keep it consistent with margin_ranking_loss.

Contributor Author:

Done, thanks.

The value sets the threshold used when the label is -1: inputs smaller than delta
incur a penalty under hinge_embedding_loss, while inputs of at least delta incur none.
Default: 1.0
reduction (str, optional): Indicates how to reduce the loss over the batch.
The candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Shape:

input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

label: N-D Tensor, same shape as the input. Tensor elements should contain 1 or -1; the data type is float32 or float64.

output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

Returns:
Tensor. The tensor variable storing the hinge_embedding_loss of input and label.

Examples:
.. code-block:: python

import paddle
import paddle.nn.functional as F

input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
# label elements in {1., -1.}
label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

loss = F.hinge_embedding_loss(input, label, delta=1.0, reduction='none')
print(loss)
# Tensor([[0., -2., 0.],
# [0., -1., 2.],
# [1., 1., 1.]])

loss = F.hinge_embedding_loss(input, label, delta=1.0, reduction='mean')
print(loss)
# Tensor([0.22222222])
"""
S-HuaBomb marked this conversation as resolved.

if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'hinge_embedding_loss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))

if not paddle.fluid.framework.in_dygraph_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'hinge_embedding_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'hinge_embedding_loss')

loss = paddle.where(label == 1., input, paddle.to_tensor(0.)) + \
paddle.where(label == -1., delta - input, paddle.to_tensor(0.))
Contributor:

Isn't max(0, delta - input) still needed here?

Contributor Author:

Yes, that was careless of me; it has been fixed. Done, thanks.


if reduction == 'mean':
return paddle.mean(loss, name=name)
elif reduction == 'sum':
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss
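As a quick sanity check on the docstring example, the piecewise rule can be reproduced directly in NumPy. This is an illustrative cross-check rather than part of the PR, and it uses the corrected branch max(0, delta - x) for y = -1 that the review thread above settles on:

import numpy as np

delta = 1.0
x = np.array([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=np.float32)
y = np.array([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=np.float32)

# loss = x where y == 1, and max(0, delta - x) where y == -1
loss_none = np.where(y == -1., np.maximum(0., delta - x), x)
print(loss_none)
# [[ 0. -2.  0.]
#  [ 0. -1.  2.]
#  [ 1.  1.  1.]]
print(loss_none.mean())  # 0.22222222, matching reduction='mean'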
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
@@ -73,6 +73,7 @@
from .loss import MarginRankingLoss # noqa: F401
from .loss import CTCLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import HingeEmbeddingLoss # noqa: F401
S-HuaBomb marked this conversation as resolved.
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
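The class-style counterpart HingeEmbeddingLoss is exported by the __init__ changes above, although its implementation in python/paddle/nn/layer/loss.py is not shown in this snapshot. A minimal usage sketch, assuming the class simply wraps the functional form with the same delta and reduction parameters (as the test file's paddle.nn.loss.HingeEmbeddingLoss calls suggest):

import paddle

input = paddle.to_tensor([[1., -2., 3.], [0., -1., 2.], [1., 0., 1.]])
label = paddle.to_tensor([[-1., 1., -1.], [1., 1., 1.], [1., -1., 1.]])

# Assumed to mirror the functional API; reduction accepts 'none' | 'mean' | 'sum'.
hinge_embedding_loss = paddle.nn.HingeEmbeddingLoss(reduction='mean')
loss = hinge_embedding_loss(input, label)
print(loss)  # Tensor([0.22222222])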