diff --git a/README.md b/README.md
index f8eb0c7b..9c29da8e 100644
--- a/README.md
+++ b/README.md
@@ -18,16 +18,16 @@
## News
+**July 25**: v2.3.0
+- Added [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
+- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
+
**June 18**: v2.2.0
- Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss).
- Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss).
- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0).
- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
-**April 5**: v2.1.0
-- Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss)
-- Thanks you [interestingzhuo](https://github.com/interestingzhuo).
-
## Documentation
- [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/)
@@ -227,7 +227,7 @@ Thanks to the contributors who made pull requests!
| Contributor | Highlights |
| -- | -- |
-|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
+|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
- [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
|[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
|[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper|
|[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
@@ -246,6 +246,7 @@ Thanks to the contributors who made pull requests!
| [layumi](https://github.com/layumi) | [InstanceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) |
| [NoTody](https://github.com/NoTody) | Helped add `ref_emb` and `ref_labels` to the distributed wrappers. |
| [ElisonSherton](https://github.com/ElisonSherton) | Fixed an edge case in ArcFaceLoss. |
+| [stompsjo](https://github.com/stompsjo) | Improved documentation for NTXentLoss |
| [z1w](https://github.com/z1w) | |
| [thinline72](https://github.com/thinline72) | |
| [tpanum](https://github.com/tpanum) | |
@@ -259,6 +260,7 @@ Thanks to the contributors who made pull requests!
| [michaeldeyzel](https://github.com/michaeldeyzel) | |
| [HSinger04](https://github.com/HSinger04) | |
| [rheum](https://github.com/rheum) | |
+| [bot66](https://github.com/bot66) | |
diff --git a/docs/losses.md b/docs/losses.md
index d5c1f4b4..85509cf9 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -807,6 +807,27 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse
- [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf){target=_blank}
- [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank}
- [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank}
+
+??? "How exactly is the NTXentLoss computed?"
+
+ In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by itself and all negative pairs in the batch that have the same "anchor" embedding (`k_i in K`).
+
+ - What does "anchor" mean? Let's say we have 3 pairs specified by batch indices: (0, 1), (0, 2), (1, 0). The first two pairs start with 0, so they have the same anchor. The third pair has the same indices as the first pair, but the order is different, so it does not have the same anchor.
+
+ Given `embeddings` with corresponding `labels`, positive pairs `(embeddings[i], embeddings[j])` are defined when `labels[i] == labels[j]`. Now let's look at an example loss calculation:
+
+ Consider `labels = [0, 0, 1, 2]`. Two losses will be computed:
+
+ * A positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`.
+
+ * A positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`.
+
+ Labels `1`, and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used.
+
+ Note that an anchor can belong to multiple positive pairs if its label is present multiple times in `labels`.
+
+ Are you trying to use `NTXentLoss` for self-supervised learning? Specifically, do you have two sets of embeddings which are derived from data that are augmented versions of each other? If so, you can skip the step of creating the `labels` array, by wrapping `NTXentLoss` with [`SelfSupervisedLoss`](losses.md#selfsupervisedloss).
+
```python
losses.NTXentLoss(temperature=0.07, **kwargs)
```
diff --git a/src/pytorch_metric_learning/losses/pnp_loss.py b/src/pytorch_metric_learning/losses/pnp_loss.py
index 50996e9f..107244ad 100644
--- a/src/pytorch_metric_learning/losses/pnp_loss.py
+++ b/src/pytorch_metric_learning/losses/pnp_loss.py
@@ -68,8 +68,8 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
else:
raise Exception(f"variant <{self.variant}> not available!")
- loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1)
- loss = torch.sum(loss) / N
+ loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1)
+ loss = torch.sum(loss) / torch.sum(safe_N)
if self.variant == "Dq":
loss = 1 - loss
diff --git a/tests/losses/test_pnp_loss.py b/tests/losses/test_pnp_loss.py
index 401cb6dc..6a1ad3ea 100644
--- a/tests/losses/test_pnp_loss.py
+++ b/tests/losses/test_pnp_loss.py
@@ -130,3 +130,11 @@ def test_pnp_loss(self):
with self.assertRaises(ValueError):
PNPLoss(b, alpha, anneal, "PNP")
+
+ def test_negatives_that_have_no_positives(self):
+ loss_func = PNPLoss()
+ labels = torch.tensor([1, 1, 2], device=TEST_DEVICE)
+ for dtype in TEST_DTYPES:
+ embeddings = torch.randn(3, 32, dtype=dtype, device=TEST_DEVICE)
+ loss = loss_func(embeddings, labels)
+ self.assertTrue(not torch.isnan(loss))
diff --git a/tests/reducers/test_setting_reducers.py b/tests/reducers/test_setting_reducers.py
index 202145b5..8667bae0 100644
--- a/tests/reducers/test_setting_reducers.py
+++ b/tests/reducers/test_setting_reducers.py
@@ -18,7 +18,7 @@ def test_setting_reducers(self):
]:
L = loss(reducer=reducer)
if isinstance(L, TripletMarginLoss):
- assert type(L.reducer) == type(reducer)
+ assert type(L.reducer) is type(reducer)
else:
for v in L.reducer.reducers.values():
- assert type(v) == type(reducer)
+ assert type(v) is type(reducer)