[Metrics] class based embedding similarity + tests #3358

Merged (16 commits), Sep 11, 2020
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -13,6 +13,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/))

- Added `EmbeddingSimilarity` metric:
* functional interface ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349))
* class based interface + tests ([#3358](https://github.com/PyTorchLightning/pytorch-lightning/pull/3358))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
@@ -134,7 +138,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed adding val step argument to metrics ([#2986](https://github.com/PyTorchLightning/pytorch-lightning/pull/2986))
- Fixed an issue that caused `Trainer.test()` to stall in ddp mode ([#2997](https://github.com/PyTorchLightning/pytorch-lightning/pull/2997))
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

4 changes: 2 additions & 2 deletions docs/source/metrics.rst
@@ -162,7 +162,8 @@ EmbeddingSimilarity
^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.self_supervised.EmbeddingSimilarity

:noindex:

F1
^^

@@ -634,4 +635,3 @@ MeanTweedieDeviance (sk)

.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance
:noindex:

33 changes: 14 additions & 19 deletions pytorch_lightning/metrics/self_supervised.py
@@ -25,35 +25,30 @@

class EmbeddingSimilarity(TensorMetric):
"""
Computes similarity between embeddings

Example:
>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
>>> embedding_similarity(embeddings)
tensor([[0.0000, 1.0000, 0.9759],
[1.0000, 0.0000, 0.9759],
[0.9759, 0.9759, 0.0000]])

"""

def __init__(self,
similarity: str = 'cosine',
zero_diagonal: bool = True,
reduction: str = 'mean',
reduce_group: Any = None):
def __init__(
self,
similarity: str = 'cosine',
zero_diagonal: bool = True,
reduction: str = 'mean',
reduce_group: Any = None
):
"""
Computes similarity between embeddings

Example:

>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
>>> embedding_similarity(embeddings)
tensor([[0.0000, 1.0000, 0.9759],
[1.0000, 0.0000, 0.9759],
[0.9759, 0.9759, 0.0000]])

Args:
similarity: 'dot' or 'cosine'
reduction: 'none', 'sum', 'mean' (all along dim -1)
zero_diagonal: if True, the diagonals are set to zero
reduce_group: the process group to reduce metric results from DDP

Return:
A square matrix (batch, batch) with the similarity scores between all elements
If sum or mean are used, then returns (b, 1) with the reduced value for each row
"""
super().__init__(name='embedding_similarity',
reduce_group=reduce_group)
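
For readers who want to see what the documented arguments amount to, here is a minimal pure-Python sketch of a pairwise embedding similarity. It is not the implementation from this PR: the function name `embedding_similarity_sketch` is hypothetical, it works on nested lists rather than tensors, and it returns a flat list of per-row values when `reduction` is `'sum'` or `'mean'`, which is an assumption made for simplicity. It only follows the semantics described in the docstring above (`similarity`, `zero_diagonal`, `reduction`).

```python
import math
from typing import List, Union

Matrix = List[List[float]]


def embedding_similarity_sketch(
    batch: Matrix,
    similarity: str = "cosine",
    zero_diagonal: bool = True,
    reduction: str = "none",
) -> Union[Matrix, List[float]]:
    """Illustrative pairwise similarity between the rows of ``batch``.

    Hypothetical helper, not the library code. Assumes no all-zero rows
    when ``similarity='cosine'`` (the norm would be zero).
    """
    if similarity not in ("dot", "cosine"):
        raise ValueError("similarity must be 'dot' or 'cosine'")
    if reduction not in ("none", "sum", "mean"):
        raise ValueError("reduction must be 'none', 'sum' or 'mean'")

    rows = batch
    if similarity == "cosine":
        # Scale every row to unit length, so the dot product computed
        # below equals the cosine of the angle between two rows.
        rows = []
        for row in batch:
            norm = math.sqrt(sum(x * x for x in row))
            rows.append([x / norm for x in row])

    n = len(rows)
    # Entry (i, j) is the dot product of row i with row j.
    sim = [
        [sum(a * b for a, b in zip(rows[i], rows[j])) for j in range(n)]
        for i in range(n)
    ]

    if zero_diagonal:
        # Drop the self-similarity of each embedding.
        for i in range(n):
            sim[i][i] = 0.0

    # Reduce along the last dimension, one value per row.
    if reduction == "sum":
        return [sum(r) for r in sim]
    if reduction == "mean":
        return [sum(r) / n for r in sim]
    return sim


embeddings = [[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]
for row in embedding_similarity_sketch(embeddings):
    print([round(v, 4) for v in row])
```

With the default arguments, the rounded values agree with the matrix in the docstring example: identical rows score 1.0, and the third row scores about 0.9759 against the other two (60 / sqrt(30 * 126)).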