From e26b0e2f92ce8c2f1aa921833433f4758d690e67 Mon Sep 17 00:00:00 2001
From: Jordan Stomps
Date: Wed, 12 Jul 2023 12:13:57 -0400
Subject: [PATCH 1/8] adding details to NTXentLoss documentation

---
 docs/losses.md | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/docs/losses.md b/docs/losses.md
index 2af1f905..641da746 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -787,6 +787,13 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse
 - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf){target=_blank}
 - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank}
 - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank}
+
+In the equation below, loss is computed for each positive pair, `k_+`, in a batch, normalized by all pairs in the batch, `k_i in K`.
+For each `embeddings` with `labels` and `ref_emb` with `ref_labels`, a positive pair `(embeddings[i], ref_emb[j])` is defined when `labels[i] == ref_labels[j]`.
+When `embeddings` and `ref_emb` are augmented versions of each other (e.g. SimCLR), `labels[i] == ref_labels[i]` (see [SelfSupervisedLoss](losses.md#selfsupervisedloss)).
+Note that multiple positive pairs can exist if the same label is present multiple times in `labels` and/or `ref_labels`.
+
+Instead of passing labels (`NTXentLoss(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)`), `indices_tuple` could be passed (see [`pytorch_metric_learning.utils.loss_and_miner_utils.get_all_pairs_indices`](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/utils/loss_and_miner_utils.py)).
 ```python
 losses.NTXentLoss(temperature=0.07, **kwargs)
 ```
@@ -799,6 +806,16 @@ losses.NTXentLoss(temperature=0.07, **kwargs)
 
 * **temperature**: This is tau in the above equation. The MoCo paper uses 0.07, while SimCLR uses 0.5.
 
+**Other info:**
+
+For example, consider `labels = ref_labels = [0, 0, 1, 2]`. Two losses will be computed:
+
+* Positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`.
+
+* Positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`.
+
+Labels `1` and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used.
+
 **Default distance**:
 
 - [```CosineSimilarity()```](distances.md#cosinesimilarity)
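The call pattern documented in this patch can be exercised directly. Below is a minimal sketch: the `losses.NTXentLoss` constructor and the `ref_emb`/`ref_labels` keywords are the ones named in the patch above, while the batch size, embedding dimension, and tensor values are illustrative.

```python
import torch
from pytorch_metric_learning import losses

loss_func = losses.NTXentLoss(temperature=0.07)

# labels = [0, 0, 1, 2] matches the "Other info" example: indices 0 and 1
# share a label and form the only positive pairs; labels 1 and 2 appear
# once each, so those embeddings act only as negatives.
embeddings = torch.randn(4, 128)
labels = torch.tensor([0, 0, 1, 2])
loss = loss_func(embeddings, labels)

# With a second embedding set, positive pairs are formed wherever
# labels[i] == ref_labels[j], as the patch describes.
ref_emb = torch.randn(4, 128)
ref_labels = torch.tensor([0, 0, 1, 2])
loss = loss_func(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)
```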
From fac0fe4afafc0128ad7dc587f8497a45284a8e8a Mon Sep 17 00:00:00 2001
From: Kevin Musgrave
Date: Thu, 20 Jul 2023 21:40:07 -0400
Subject: [PATCH 2/8] Minor rewording and reorganization of NTXentLoss docs

---
 docs/losses.md | 34 +++++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/docs/losses.md b/docs/losses.md
index 641da746..af3cd76e 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -788,12 +788,26 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse
 - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank}
 - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank}
 
-In the equation below, loss is computed for each positive pair, `k_+`, in a batch, normalized by all pairs in the batch, `k_i in K`.
-For each `embeddings` with `labels` and `ref_emb` with `ref_labels`, a positive pair `(embeddings[i], ref_emb[j])` is defined when `labels[i] == ref_labels[j]`.
-When `embeddings` and `ref_emb` are augmented versions of each other (e.g. SimCLR), `labels[i] == ref_labels[i]` (see [SelfSupervisedLoss](losses.md#selfsupervisedloss)).
-Note that multiple positive pairs can exist if the same label is present multiple times in `labels` and/or `ref_labels`.
+??? "How exactly is the NTXentLoss computed?"
+
+    In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by all positive and negative pairs in the batch that have the same "anchor" embedding (`k_i in K`).
+
+    - What does "anchor" mean? Let's say we have 3 pairs specified by batch indices: (0, 1), (0, 2), (1, 0). The first two pairs start with 0, so they have the same anchor. The third pair has the same indices as the first pair, but the order is different, so it does not have the same anchor.
+
+    Given `embeddings` with corresponding `labels`, positive pairs `(embeddings[i], embeddings[j])` are defined when `labels[i] == labels[j]`. Now let's look at an example loss calculation:
+
+    Consider `labels = [0, 0, 1, 2]`. Two losses will be computed:
+
+    * A positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`.
+
+    * A positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`.
+
+    Labels `1` and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used.
+
+    Note that an anchor can belong to multiple positive pairs if its label is present multiple times in `labels`.
+
+    Are you trying to use `NTXentLoss` for self-supervised learning? Specifically, do you have two sets of embeddings which are derived from data that are augmented versions of each other? If so, you can skip the step of creating the `labels` array, by wrapping `NTXentLoss` with [`SelfSupervisedLoss`](losses.md#selfsupervisedloss).
 
-Instead of passing labels (`NTXentLoss(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)`), `indices_tuple` could be passed (see [`pytorch_metric_learning.utils.loss_and_miner_utils.get_all_pairs_indices`](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/utils/loss_and_miner_utils.py)).
 ```python
 losses.NTXentLoss(temperature=0.07, **kwargs)
 ```
@@ -806,16 +820,6 @@ losses.NTXentLoss(temperature=0.07, **kwargs)
 
 * **temperature**: This is tau in the above equation. The MoCo paper uses 0.07, while SimCLR uses 0.5.
 
-**Other info:**
-
-For example, consider `labels = ref_labels = [0, 0, 1, 2]`. Two losses will be computed:
-
-* Positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`.
-
-* Positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`.
-
-Labels `1` and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used.
-
 **Default distance**:
 
 - [```CosineSimilarity()```](distances.md#cosinesimilarity)
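For the self-supervised case that the rewritten section points to, the `labels` array can be skipped entirely by wrapping the loss. A minimal sketch, assuming `SelfSupervisedLoss` pairs row `i` of `embeddings` with row `i` of `ref_emb` as its own documentation describes; the sizes and tensor values here are illustrative.

```python
import torch
from pytorch_metric_learning import losses

# SimCLR-style setup: NTXentLoss wrapped so that no labels are needed.
# The docs note that SimCLR used temperature=0.5.
loss_func = losses.SelfSupervisedLoss(losses.NTXentLoss(temperature=0.5))

embeddings = torch.randn(32, 128)  # encoder output for augmentation A
ref_emb = torch.randn(32, 128)     # encoder output for augmentation B of the same 32 inputs

# Equivalent to passing labels == ref_labels == arange(32):
loss = loss_func(embeddings, ref_emb)
```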
From 36cfd0d9a1a5dea2b22d3b913e96d4281c00a6bb Mon Sep 17 00:00:00 2001
From: Kevin Musgrave
Date: Thu, 20 Jul 2023 22:10:38 -0400
Subject: [PATCH 3/8] Update README.md

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index f8eb0c7b..728a42d2 100644
--- a/README.md
+++ b/README.md
@@ -246,6 +246,7 @@ Thanks to the contributors who made pull requests!
 | [layumi](https://github.com/layumi) | [InstanceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) |
 | [NoTody](https://github.com/NoTody) | Helped add `ref_emb` and `ref_labels` to the distributed wrappers. |
 | [ElisonSherton](https://github.com/ElisonSherton) | Fixed an edge case in ArcFaceLoss. |
+| [stompsjo](https://github.com/stompsjo) | Improved documentation for NTXentLoss |
 | [z1w](https://github.com/z1w) | |
 | [thinline72](https://github.com/thinline72) | |
 | [tpanum](https://github.com/tpanum) | |

From 8e843863d00014c1b5294a1dd1c245118b74e1dc Mon Sep 17 00:00:00 2001
From: Kevin Musgrave
Date: Fri, 21 Jul 2023 15:43:33 -0400
Subject: [PATCH 4/8] minor doc correction

---
 docs/losses.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/losses.md b/docs/losses.md
index af3cd76e..05312656 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -790,7 +790,7 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse
 
 ??? "How exactly is the NTXentLoss computed?"
 
-    In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by all positive and negative pairs in the batch that have the same "anchor" embedding (`k_i in K`).
+    In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by itself and all negative pairs in the batch that have the same "anchor" embedding (`k_i in K`).
 
     - What does "anchor" mean? Let's say we have 3 pairs specified by batch indices: (0, 1), (0, 2), (1, 0). The first two pairs start with 0, so they have the same anchor. The third pair has the same indices as the first pair, but the order is different, so it does not have the same anchor.
 
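The wording corrected in PATCH 4 refers to an equation that is not part of the diff. For orientation, the standard NT-Xent/InfoNCE form consistent with the corrected sentence is sketched below; the notation (`k_+`, `k_i`, `K`, tau) follows the docs, but this rendering is a reconstruction, not a line from losses.md.

```latex
\mathcal{L}_{k_+} = -\log
  \frac{\exp\left(\mathrm{sim}(k_+)/\tau\right)}
       {\sum_{k_i \in K} \exp\left(\mathrm{sim}(k_i)/\tau\right)}
```

Here `K` contains the positive pair `k_+` itself plus every negative pair sharing the same anchor, which is exactly the "normalized by itself and all negative pairs in the batch that have the same anchor embedding" phrasing introduced above.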
From cf82af53dc91ffb54e3bd6c055f78e8920506685 Mon Sep 17 00:00:00 2001
From: Kevin Musgrave
Date: Tue, 25 Jul 2023 10:59:12 -0400
Subject: [PATCH 5/8] Update README.md

---
 README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 728a42d2..846dae25 100644
--- a/README.md
+++ b/README.md
@@ -18,16 +18,16 @@
 
 ## News
 
+**July 25**: v2.3.0
+- Added [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
+- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
+
 **June 18**: v2.2.0
 - Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss).
 - Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss).
 - See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0).
 - Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
 
-**April 5**: v2.1.0
-- Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss)
-- Thank you [interestingzhuo](https://github.com/interestingzhuo).
-
 ## Documentation
 - [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/)
@@ -227,7 +227,7 @@ Thanks to the contributors who made pull requests!
 
 | Contributor | Highlights |
 | -- | -- |
-|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)<br/> - [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
+|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)<br/> - [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)<br/> -[HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
 |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets<br/> - Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
 |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)<br/> - [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)<br/> - Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)<br/> - BaseLossWrapper|
 |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)<br/> - [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)<br/> - [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)<br/> - [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
From ac607007dc62666f9de850cd5b8e5694ff0da1c2 Mon Sep 17 00:00:00 2001
From: Kevin Musgrave
Date: Tue, 25 Jul 2023 11:00:08 -0400
Subject: [PATCH 6/8] Update README.md

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 846dae25..9c29da8e 100644
--- a/README.md
+++ b/README.md
@@ -227,7 +227,7 @@ Thanks to the contributors who made pull requests!
 
 | Contributor | Highlights |
 | -- | -- |
-|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)<br/> - [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)<br/> -[HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
+|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)<br/> - [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)<br/> - [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
 |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets<br/> - Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
 |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)<br/> - [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)<br/> - Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)<br/> - BaseLossWrapper|
 |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)<br/> - [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)<br/> - [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)<br/> - [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
@@ -260,6 +260,7 @@ Thanks to the contributors who made pull requests!
 | [michaeldeyzel](https://github.com/michaeldeyzel) | |
 | [HSinger04](https://github.com/HSinger04) | |
 | [rheum](https://github.com/rheum) | |
+| [bot66](https://github.com/bot66) | |

From 25f800f0bfdd9aab171795dc2251bdc18a431c0d Mon Sep 17 00:00:00 2001
From: Dmitry Nikitko
Date: Tue, 12 Sep 2023 12:00:29 +0400
Subject: [PATCH 7/8] Fix PNP loss to make it work with negatives without related positive classes

---
 src/pytorch_metric_learning/losses/pnp_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_metric_learning/losses/pnp_loss.py b/src/pytorch_metric_learning/losses/pnp_loss.py
index 50996e9f..8c5085b1 100644
--- a/src/pytorch_metric_learning/losses/pnp_loss.py
+++ b/src/pytorch_metric_learning/losses/pnp_loss.py
@@ -68,7 +68,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         else:
             raise Exception(f"variant <{self.variant}> not available!")
 
-        loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1)
+        loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1)
         loss = torch.sum(loss) / N
         if self.variant == "Dq":
             loss = 1 - loss

From 5ba07ba5c7a26c8902f475ad3c74932ec2c29436 Mon Sep 17 00:00:00 2001
From: Kevin Musgrave
Date: Sat, 11 Nov 2023 19:00:51 +0000
Subject: [PATCH 8/8] Added a test and fixed the denominator

---
 src/pytorch_metric_learning/losses/pnp_loss.py | 2 +-
 tests/losses/test_pnp_loss.py                  | 8 ++++++++
 tests/reducers/test_setting_reducers.py        | 4 ++--
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_metric_learning/losses/pnp_loss.py b/src/pytorch_metric_learning/losses/pnp_loss.py
index 8c5085b1..107244ad 100644
--- a/src/pytorch_metric_learning/losses/pnp_loss.py
+++ b/src/pytorch_metric_learning/losses/pnp_loss.py
@@ -69,7 +69,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
             raise Exception(f"variant <{self.variant}> not available!")
 
         loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1)
-        loss = torch.sum(loss) / N
+        loss = torch.sum(loss) / torch.sum(safe_N)
         if self.variant == "Dq":
             loss = 1 - loss

diff --git a/tests/losses/test_pnp_loss.py b/tests/losses/test_pnp_loss.py
index 401cb6dc..6a1ad3ea 100644
--- a/tests/losses/test_pnp_loss.py
+++ b/tests/losses/test_pnp_loss.py
@@ -130,3 +130,11 @@ def test_pnp_loss(self):
 
         with self.assertRaises(ValueError):
             PNPLoss(b, alpha, anneal, "PNP")
+
+    def test_negatives_that_have_no_positives(self):
+        loss_func = PNPLoss()
+        labels = torch.tensor([1, 1, 2], device=TEST_DEVICE)
+        for dtype in TEST_DTYPES:
+            embeddings = torch.randn(3, 32, dtype=dtype, device=TEST_DEVICE)
+            loss = loss_func(embeddings, labels)
+            self.assertTrue(not torch.isnan(loss))

diff --git a/tests/reducers/test_setting_reducers.py b/tests/reducers/test_setting_reducers.py
index 202145b5..8667bae0 100644
--- a/tests/reducers/test_setting_reducers.py
+++ b/tests/reducers/test_setting_reducers.py
@@ -18,7 +18,7 @@ def test_setting_reducers(self):
         ]:
             L = loss(reducer=reducer)
             if isinstance(L, TripletMarginLoss):
-                assert type(L.reducer) == type(reducer)
+                assert type(L.reducer) is type(reducer)
             else:
                 for v in L.reducer.reducers.values():
-                    assert type(v) == type(reducer)
+                    assert type(v) is type(reducer)
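Taken together, PATCH 7 and PATCH 8 make `PNPLoss` mask out anchors that have no positive pairs (the `safe_N` indexing) and average over only the surviving anchors (the `torch.sum(safe_N)` denominator). A minimal sketch of the behavior the new test pins down, using only the public API; the embedding dimension and tensor values are illustrative.

```python
import torch
from pytorch_metric_learning.losses import PNPLoss

loss_func = PNPLoss()

# Label 2 appears only once, so its embedding can only serve as a negative.
# Before this fix, the corresponding anchor row produced a 0/0 term and the
# final loss could come out NaN; now that row is excluded from the average.
embeddings = torch.randn(3, 32)
labels = torch.tensor([1, 1, 2])

loss = loss_func(embeddings, labels)
assert not torch.isnan(loss)
```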