Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additive margin softmax #1131

Merged
merged 6 commits into from
Mar 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 27 additions & 20 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Additive Margin SoftMax(AMSoftmax)([#1125](https://github.com/catalyst-team/catalyst/issues/1125))

### Added

- Generalized Mean Pooling(GeM)([#1084](https://github.com/catalyst-team/catalyst/issues/1084))

- Generalized Mean Pooling(GeM) ([#1084](https://github.com/catalyst-team/catalyst/issues/1084))
- Key-value support for CriterionCallback ([#1130](https://github.com/catalyst-team/catalyst/issues/1130))
- Engine configuration through cmd ([#1134](https://github.com/catalyst-team/catalyst/issues/1134))
Expand All @@ -22,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-


### Fixed

Expand Down Expand Up @@ -97,8 +104,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- ([#1002](https://github.com/catalyst-team/catalyst/pull/1002))
- a few docs
- ([#998](https://github.com/catalyst-team/catalyst/pull/998))
- ``reciprocal_rank`` metric
- unified recsys metrics preprocessing
- ``reciprocal_rank`` metric
- unified recsys metrics preprocessing
- ([#1018](https://github.com/catalyst-team/catalyst/pull/1018))
- readme examples for all supported metrics under ``catalyst.metrics``
- ``wrap_metric_fn_with_activation`` for model outputs wrapping with activation
Expand Down Expand Up @@ -146,7 +153,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- ([#1018](https://github.com/catalyst-team/catalyst/pull/1014))
- ClasswiseIouCallback/ClasswiseJaccardCallback as deprecated on (should be refactored in future releases)



### Fixed

Expand All @@ -163,7 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added
- DCG, nDCG metrics ([#881](https://github.com/catalyst-team/catalyst/pull/881))
- MAP calculations [#968](https://github.com/catalyst-team/catalyst/pull/968)
- MAP calculations [#968](https://github.com/catalyst-team/catalyst/pull/968)
- hitrate calculations [#975] (https://github.com/catalyst-team/catalyst/pull/975)
- extra functions for classification metrics ([#966](https://github.com/catalyst-team/catalyst/pull/966))
- `OneOf` and `OneOfV2` batch transforms ([#951](https://github.com/catalyst-team/catalyst/pull/951))
Expand Down Expand Up @@ -197,7 +204,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- MRR metrics calculation ([#886](https://github.com/catalyst-team/catalyst/pull/886))
- docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
- docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
- ArcMargin layer to contrib ([#957](https://github.com/catalyst-team/catalyst/pull/957))
- AdaCos to contrib ([#958](https://github.com/catalyst-team/catalyst/pull/958))
Expand All @@ -218,7 +225,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
-

### Fixed

Expand All @@ -243,7 +250,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
-

### Fixed

Expand Down Expand Up @@ -271,13 +278,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
-

### Fixed

- autoresume option for Config API ([#907](https://github.com/catalyst-team/catalyst/pull/907))
- a few issues with TF projector ([#917](https://github.com/catalyst-team/catalyst/pull/917))
- batch sampler speed issue ([#921](https://github.com/catalyst-team/catalyst/pull/921))
- batch sampler speed issue ([#921](https://github.com/catalyst-team/catalyst/pull/921))
- add apex key-value optimizer support ([#924](https://github.com/catalyst-team/catalyst/pull/924))
- runtime warning for PyTorch 1.6 ([920](https://github.com/catalyst-team/catalyst/pull/920))
- Apex synbn usage ([920](https://github.com/catalyst-team/catalyst/pull/920))
Expand Down Expand Up @@ -377,7 +384,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
-

### Fixed

Expand Down Expand Up @@ -407,7 +414,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
-

### Fixed

Expand Down Expand Up @@ -461,35 +468,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [20.04] - 2020-04-06

### Added
-


### Changed

-
-

### Removed

-
-

### Fixed

-
-


## [YY.MM.R] - YYYY-MM-DD

### Added

-
-

### Changed

-
-

### Removed

-
-

### Fixed

-
-
1 change: 1 addition & 0 deletions catalyst/contrib/nn/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
from torch.nn.modules import *

from catalyst.contrib.nn.modules.amsoftmax import AMSoftmax
from catalyst.contrib.nn.modules.arcface import ArcFace, SubCenterArcFace
from catalyst.contrib.nn.modules.arcmargin import ArcMarginProduct
from catalyst.contrib.nn.modules.common import (
Expand Down
106 changes: 106 additions & 0 deletions catalyst/contrib/nn/modules/amsoftmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class AMSoftmax(nn.Module):
"""Implementation of
`AMSoftmax: Additive Margin Softmax for Face Verification`_.

.. _AMSoftmax\: Additive Margin Softmax for Face Verification:
https://arxiv.org/pdf/1801.05599.pdf

Args:
in_features: size of each input sample.
out_features: size of each output sample.
s: norm of input feature.
Default: ``64.0``.
m: margin.
Default: ``0.5``.
eps: operation accuracy.
Default: ``1e-6``.

Shape:
- Input: :math:`(batch, H_{in})` where
:math:`H_{in} = in\_features`.
- Output: :math:`(batch, H_{out})` where
:math:`H_{out} = out\_features`.

Example:
>>> layer = AMSoftmax(5, 10, s=1.31, m=0.5)
>>> loss_fn = nn.CrossEntropyLoss()
>>> embedding = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(10)
>>> output = layer(embedding, target)
>>> loss = loss_fn(output, target)
>>> loss.backward()

"""

def __init__( # noqa: D107
self,
in_features: int,
out_features: int,
s: float = 64.0,
m: float = 0.5,
eps: float = 1e-6,
):
super(AMSoftmax, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.eps = eps

self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
nn.init.xavier_uniform_(self.weight)

def __repr__(self) -> str:
"""Object representation."""
rep = (
"ArcFace("
f"in_features={self.in_features},"
f"out_features={self.out_features},"
f"s={self.s},"
f"m={self.m},"
f"eps={self.eps}"
")"
)
return rep

def forward(self, input: torch.Tensor, target: torch.LongTensor = None) -> torch.Tensor:
"""
Args:
input: input features,
expected shapes ``BxF`` where ``B``
is batch dimension and ``F`` is an
input feature dimension.
target: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None` then will be returned
projection on centroids.
Default is `None`.

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes
(out_features).
"""
cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))

if target is None:
return cos_theta

cos_theta = torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)

one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

logits = torch.where(one_hot.bool(), cos_theta - self.m, cos_theta)
logits *= self.s

return logits


__all__ = ["AMSoftmax"]
2 changes: 1 addition & 1 deletion catalyst/contrib/nn/modules/arcface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ArcFace(nn.Module):

Example:
>>> layer = ArcFace(5, 10, s=1.31, m=0.5)
>>> loss_fn = nn.CrosEntropyLoss()
>>> loss_fn = nn.CrossEntropyLoss()
>>> embedding = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(10)
>>> output = layer(embedding, target)
Expand Down
Loading