diff --git a/.gitignore b/.gitignore index ef79da47..45431d45 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dist/ *.egg-info/ site/ venv/ +**/.vscode .ipynb_checkpoints examples/notebooks/dataset examples/notebooks/CIFAR10_Dataset diff --git a/CONTENTS.md b/CONTENTS.md index 6c4bfef3..7839ab9c 100644 --- a/CONTENTS.md +++ b/CONTENTS.md @@ -17,16 +17,19 @@ | [**CosFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#cosfaceloss) | - [CosFace: Large Margin Cosine Loss for Deep Face Recognition](https://arxiv.org/pdf/1801.09414.pdf)
- [Additive Margin Softmax for Face Verification](https://arxiv.org/pdf/1801.05599.pdf) | [**FastAPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#fastaploss) | [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/papers/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.pdf) | [**GeneralizedLiftedStructureLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#generalizedliftedstructureloss) | [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737.pdf) +| [**HistogramLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) | [Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf) | [**InstanceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) | [Dual-Path Convolutional Image-Text Embeddings with Instance Loss](https://arxiv.org/pdf/1711.05535.pdf) | [**IntraPairVarianceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#intrapairvarianceloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf) | [**LargeMarginSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#largemarginsoftmaxloss) | [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf) | [**LiftedStructreLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#liftedstructureloss) | [Deep Metric Learning via Lifted Structured Feature Embedding](https://arxiv.org/pdf/1511.06452.pdf) +| [**ManifoldLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) | [Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf) | [**MarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#marginloss) | [Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf) | [**MultiSimilarityLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#multisimilarityloss) | [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) | [**NCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ncaloss) | [Neighbourhood Components Analysis](https://www.cs.toronto.edu/~hinton/absps/nca.pdf) | [**NormalizedSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#normalizedsoftmaxloss) | - [NormFace: L2 Hypersphere Embedding for Face Verification](https://arxiv.org/pdf/1704.06369.pdf)
- [Classification is a Strong Baseline for DeepMetric Learning](https://arxiv.org/pdf/1811.12649.pdf) | [**NPairsLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#npairsloss) | [Improved Deep Metric Learning with Multi-class N-pair Loss Objective](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf) | [**NTXentLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss) | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf)
- [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf)
- [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709) +| [**P2SGradLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) | [P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479) | [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) | [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf) | [**ProxyAnchorLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyanchorloss) | [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf) | [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf) diff --git a/README.md b/README.md index 26d568ae..f8eb0c7b 100644 --- a/README.md +++ b/README.md @@ -18,13 +18,15 @@ ## News +**June 18**: v2.2.0 +- Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss). +- Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss). +- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0). +- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0). + **April 5**: v2.1.0 - Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) -- Thanks to contributor [interestingzhuo](https://github.com/interestingzhuo). - -**January 29**: v2.0.0 -- Added SelfSupervisedLoss, plus various API improvements. See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.0.0). -- Thanks to contributor [cwkeam](https://github.com/cwkeam). +- Thanks you [interestingzhuo](https://github.com/interestingzhuo). ## Documentation @@ -225,6 +227,7 @@ Thanks to the contributors who made pull requests! | Contributor | Highlights | | -- | -- | +|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons | |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper| |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) | @@ -273,6 +276,7 @@ This library contains code that has been adapted and modified from the following - https://github.com/ronekko/deep_metric_learning - https://github.com/tjddus9597/Proxy-Anchor-CVPR2020 - http://kaizhao.net/regularface +- https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts ### Logo Thanks to [Jeff Musgrave](https://www.designgenius.ca/) for designing the logo. diff --git a/docs/losses.md b/docs/losses.md index b0719258..d5c1f4b4 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -424,6 +424,26 @@ losses.InstanceLoss(gamma=64, **kwargs) * **gamma**: The cosine similarity matrix is scaled by this amount. +## HistogramLoss +[Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf) +```python +losses.HistogramLoss(n_bins=None, delta=None) +``` + +**Parameters**: + +* **n_bins**: The number of bins used to construct the histogram. Default is 100 when both `n_bins` and `delta` are `None`. +* **delta**: The mesh of the uniform partition of the interval [-1, 1] used to construct the histogram. If not set the value of n_bins will be used. + +**Default distance**: + + - [```CosineSimilarity()```](distances.md#cosinesimilarity) + +**Default reducer**: + + - This loss returns an **already reduced** loss. + + ## IntraPairVarianceLoss [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf){target=_blank} ```python @@ -545,6 +565,57 @@ losses.LiftedStructureLoss(neg_margin=1, pos_margin=0, **kwargs): * **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```. +## ManifoldLoss + +[Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf) + +```python +losses.ManifoldLoss( + l: int, + K: int = 50, + lambdaC: float = 1.0, + alpha: float = 0.8, + margin: float = 5e-4, + **kwargs + ) +``` + +**Parameters** + +- **l**: embedding size. + +- **K**: number of proxies. + +- **lambdaC**: regularization weight. Used in the formula `loss = intrinsic_loss + lambdaC*context_loss`. + If `lambdaC=0`, then it uses only the intrinsic loss. If `lambdaC=np.inf`, then it uses only the context loss. + +- **alpha**: parameter of the Random Walk. Must be in the range `(0,1)`. It specifies the amount of similarity between neighboring nodes. + +- **margin**: margin used in the calculation of the loss. + + +Example usage: +```python +loss_fn = ManifoldLoss(128) + +# use random cluster centers +loss = loss_fn(embeddings) +# or specify indices of embeddings to use as cluster centers +loss = loss_fn(embeddings, indices_tuple=indices) +``` + +**Important notes** + +`labels`, `ref_emb`, and `ref_labels` are not supported for this loss function. + +In addition, `indices_tuple` is **not** for the output of miners. Instead, it is for a list of indices of embeddings to be used as cluster centers. + + +**Default reducer**: + + - This loss returns an **already reduced** loss. + + ## MarginLoss [Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf){target=_blank} ```python @@ -761,6 +832,37 @@ losses.NTXentLoss(temperature=0.07, **kwargs) * **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```. + +## P2SGradLoss +[P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479) +```python +losses.P2SGradLoss(descriptors_dim, num_classes, **kwargs) +``` + +**Parameters** + +- **descriptors_dim**: The embedding size. + +- **num_classes**: The number of classes in your training dataset. + + +Example usage: +```python +loss_fn = P2SGradLoss(128, 10) +loss = loss_fn(embeddings, labels) +``` + +**Important notes** + +`indices_tuple`, `ref_emb`, and `ref_labels` are not supported for this loss function. + + +**Default reducer**: + + - This loss returns an **already reduced** loss. + + + ## PNPLoss [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf){target=_blank} ```python @@ -849,14 +951,31 @@ loss_optimizer.step() ## SelfSupervisedLoss -A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`. `SelfSupervisedLoss` automates this. +A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`. + +`SelfSupervisedLoss` is a wrapper that takes care of this by creating labels internally. It assumes that: + +- `ref_emb[i]` is an augmented version of `embeddings[i]`. +- `ref_emb[i]` is the only augmented version of `embeddings[i]` in the batch. ```python +losses.SelfSupervisedLoss(loss, symmetric=True, **kwargs) +``` + +**Parameters**: + +* **loss**: The loss function to be wrapped. +* **symmetric**: If `True`, then the embeddings in both `embeddings` and `ref_emb` are used as anchors. If `False`, then only the embeddings in `embeddings` are used as anchors. + +Example usage: + +``` loss_fn = losses.TripletMarginLoss() loss_fn = SelfSupervisedLoss(loss_fn) loss = loss_fn(embeddings, ref_emb) ``` + ??? "Supported Loss Functions" - [AngularLoss](losses.md#angularloss) - [CircleLoss](losses.md#circleloss) diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index 8a124bf6..55e47090 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.2.0" +__version__ = "2.3.0" diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 10e841f1..a0ba7407 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -8,6 +8,7 @@ from .cross_batch_memory import CrossBatchMemory from .fast_ap_loss import FastAPLoss from .generic_pair_loss import GenericPairLoss +from .histogram_loss import HistogramLoss from .instance_loss import InstanceLoss from .intra_pair_variance_loss import IntraPairVarianceLoss from .large_margin_softmax_loss import LargeMarginSoftmaxLoss diff --git a/src/pytorch_metric_learning/losses/histogram_loss.py b/src/pytorch_metric_learning/losses/histogram_loss.py new file mode 100644 index 00000000..44899fcb --- /dev/null +++ b/src/pytorch_metric_learning/losses/histogram_loss.py @@ -0,0 +1,80 @@ +import torch + +from ..distances import CosineSimilarity +from ..utils import common_functions as c_f +from ..utils import loss_and_miner_utils as lmu +from .base_metric_loss_function import BaseMetricLossFunction + + +def filter_pairs(*tensors: torch.Tensor): + t = torch.stack(tensors) + t, _ = torch.sort(t, dim=0) + t = torch.unique(t, dim=1) + return t.tolist() + + +class HistogramLoss(BaseMetricLossFunction): + def __init__(self, n_bins: int = None, delta: float = None, **kwargs): + super().__init__(**kwargs) + if delta is not None and n_bins is not None: + assert ( + delta == 2 / n_bins + ), f"delta and n_bins must satisfy the equation delta = 2/n_bins.\nPassed values are delta={delta} and n_bins={n_bins}" + + if delta is None and n_bins is None: + n_bins = 100 + + self.delta = delta if delta is not None else 2 / n_bins + self.add_to_recordable_attributes(name="delta", is_stat=True) + + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + c_f.labels_or_indices_tuple_required(labels, indices_tuple) + c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels) + indices_tuple = lmu.convert_to_triplets( + indices_tuple, labels, ref_labels, t_per_anchor="all" + ) + anchor_idx, positive_idx, negative_idx = indices_tuple + if len(anchor_idx) == 0: + return self.zero_losses() + mat = self.distance(embeddings, ref_emb) + + anchor_positive_idx = filter_pairs(anchor_idx, positive_idx) + anchor_negative_idx = filter_pairs(anchor_idx, negative_idx) + ap_dists = mat[anchor_positive_idx] + an_dists = mat[anchor_negative_idx] + + p_pos = self.compute_density(ap_dists) + phi = torch.cumsum(p_pos, dim=0) + + p_neg = self.compute_density(an_dists) + return { + "loss": { + "losses": torch.sum(p_neg * phi), + "indices": None, + "reduction_type": "already_reduced", + } + } + + def compute_density(self, distances): + size = distances.size(0) + r_star = torch.floor( + (distances.float() + 1) / self.delta + ) # Indices of the bins containing the values of the distances + r_star = c_f.to_device(r_star, tensor=distances, dtype=torch.long) + + delta_ijr_a = (distances + 1 - r_star * self.delta) / self.delta + delta_ijr_b = ((r_star + 1) * self.delta - 1 - distances) / self.delta + delta_ijr_a = c_f.to_dtype(delta_ijr_a, tensor=distances) + delta_ijr_b = c_f.to_dtype(delta_ijr_b, tensor=distances) + + density = torch.zeros(round(1 + 2 / self.delta)) + density = c_f.to_device(density, tensor=distances, dtype=distances.dtype) + + # For each node sum the contributions of the bins whose ending node is this one + density.scatter_add_(0, r_star + 1, delta_ijr_a) + # For each node sum the contributions of the bins whose starting node is this one + density.scatter_add_(0, r_star, delta_ijr_b) + return density / size + + def get_default_distance(self): + return CosineSimilarity() diff --git a/tests/losses/test_histogram_loss.py b/tests/losses/test_histogram_loss.py new file mode 100644 index 00000000..aeeb54af --- /dev/null +++ b/tests/losses/test_histogram_loss.py @@ -0,0 +1,170 @@ +import unittest + +import torch +from numpy.testing import assert_almost_equal + +from pytorch_metric_learning.losses import HistogramLoss +from pytorch_metric_learning.utils import common_functions as c_f + +from .. import TEST_DEVICE, TEST_DTYPES + + +###################################### +#######ORIGINAL IMPLEMENTATION######## +###################################### +# DIRECTLY COPIED from https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/losses.py. +# This code is copied from the official PyTorch implementation +# so that we can make sure our implementation returns the same result. +# Some minor changes were made to avoid errors during testing. +# Every change in the original code is reported and explained. +class OriginalImplementationHistogramLoss(torch.nn.Module): + def __init__(self, num_steps, cuda=True): + super(OriginalImplementationHistogramLoss, self).__init__() + self.step = 2 / (num_steps - 1) + self.eps = 1 / num_steps + self.cuda = cuda + self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1) + self.tsize = self.t.size()[0] + if self.cuda: + self.t = self.t.cuda() + + def forward(self, features, classes): + def histogram(inds, size): + s_repeat_ = s_repeat.clone() + inds = c_f.to_device(inds, tensor=s_repeat_floor) # Added to avoid errors + self.t = c_f.to_device( + self.t, tensor=s_repeat_floor + ) # Added to avoid errors + indsa = ( + (s_repeat_floor - (self.t - self.step) > -self.eps) + & (s_repeat_floor - (self.t - self.step) < self.eps) + & inds + ) + assert ( + indsa.nonzero().size()[0] == size + ), "Another number of bins should be used" + zeros = torch.zeros((1, indsa.size()[1])).to( + device=indsa.device, dtype=torch.uint8 + ) + if self.cuda: + zeros = zeros.cuda() + indsb = torch.cat((indsa, zeros))[1:, :].to( + dtype=torch.bool + ) # Added to avoid bug with masks of uint8 + s_repeat_[~(indsb | indsa)] = 0 + # indsa corresponds to the first condition of the second equation of the paper + self.t = self.t.to( + dtype=s_repeat_.dtype + ) # Added to avoid errors when using Half precision + s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step + # indsb corresponds to the second condition of the second equation of the paper + s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step + + return s_repeat_.sum(1) / size + + classes_size = classes.size()[0] + classes_eq = ( + classes.repeat(classes_size, 1) + == classes.view(-1, 1).repeat(1, classes_size) + ).data + dists = torch.mm(features, features.transpose(0, 1)) + assert ( + (dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item() + ) == 0, "L2 normalization should be used" + s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte() + if self.cuda: + s_inds = s_inds.cuda() + classes_eq = classes_eq.to( + device=s_inds.device + ) # Added to avoid errors when using only cpu + pos_inds = classes_eq[s_inds].repeat(self.tsize, 1) + neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1) + pos_size = classes_eq[s_inds].sum().item() + neg_size = (~classes_eq[s_inds]).sum().item() + s = dists[s_inds].view(1, -1) + s_repeat = s.repeat(self.tsize, 1) + s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float() + + histogram_pos = histogram(pos_inds, pos_size) + assert_almost_equal( + histogram_pos.sum().item(), + 1, + decimal=1, + err_msg="Not good positive histogram", + verbose=True, + ) + histogram_neg = histogram(neg_inds, neg_size) + assert_almost_equal( + histogram_neg.sum().item(), + 1, + decimal=1, + err_msg="Not good negative histogram", + verbose=True, + ) + histogram_pos_repeat = histogram_pos.view(-1, 1).repeat( + 1, histogram_pos.size()[0] + ) + histogram_pos_inds = torch.tril( + torch.ones(histogram_pos_repeat.size()), -1 + ).byte() + if self.cuda: + histogram_pos_inds = histogram_pos_inds.cuda() + histogram_pos_repeat[histogram_pos_inds] = 0 + histogram_pos_cdf = histogram_pos_repeat.sum(0) + loss = torch.sum(histogram_neg * histogram_pos_cdf) + + return loss + + +class TestHistogramLoss(unittest.TestCase): + def test_histogram_loss(self): + batch_size = 32 + embedding_size = 64 + for dtype in TEST_DTYPES: + num_steps = 5 if dtype == torch.float16 else 21 + num_bins = num_steps - 1 + loss_func = HistogramLoss(n_bins=num_bins) + original_loss_func = OriginalImplementationHistogramLoss( + num_steps=num_steps, cuda=False + ) + + # test multiple times + for _ in range(2): + embeddings = torch.randn( + batch_size, + embedding_size, + requires_grad=True, + dtype=dtype, + ).to(TEST_DEVICE) + labels = torch.randint(0, 5, size=(batch_size,)) + + loss = loss_func(embeddings, labels) + correct_loss = original_loss_func( + torch.nn.functional.normalize(embeddings), labels + ) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + + loss.backward() + + def test_with_no_valid_triplets(self): + loss_func = HistogramLoss(n_bins=4) + for dtype in TEST_DTYPES: + embeddings = torch.randn( + 5, + 32, + requires_grad=True, + dtype=dtype, + ).to(TEST_DEVICE) + labels = torch.LongTensor([0, 1, 2, 3, 4]) + loss = loss_func(embeddings, labels) + self.assertEqual(loss, 0) + loss.backward() + + def test_assertion_raises(self): + with self.assertRaises(AssertionError): + _ = HistogramLoss(n_bins=1, delta=0.5) + + with self.assertRaises(AssertionError): + _ = HistogramLoss(n_bins=10, delta=0.4)