diff --git a/docs/imgs/smooth_ap_approx_equation.png b/docs/imgs/smooth_ap_approx_equation.png
new file mode 100644
index 00000000..0b37c27a
Binary files /dev/null and b/docs/imgs/smooth_ap_approx_equation.png differ
diff --git a/docs/imgs/smooth_ap_loss_equation.png b/docs/imgs/smooth_ap_loss_equation.png
new file mode 100644
index 00000000..25f061c8
Binary files /dev/null and b/docs/imgs/smooth_ap_loss_equation.png differ
diff --git a/docs/imgs/smooth_ap_sigmoid_equation.png b/docs/imgs/smooth_ap_sigmoid_equation.png
new file mode 100644
index 00000000..7153f2a3
Binary files /dev/null and b/docs/imgs/smooth_ap_sigmoid_equation.png differ
diff --git a/docs/losses.md b/docs/losses.md
index c4126f0d..079cf4c2 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -1087,6 +1087,56 @@ losses.SignalToNoiseRatioContrastiveLoss(pos_margin=0, neg_margin=1, **kwargs):
 * **pos_loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
 * **neg_loss**: The loss per negative pair in the batch. Reduction type is ```"neg_pair"```.
 
+## SmoothAPLoss
+[Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163){target=_blank}
+
+```python
+losses.SmoothAPLoss(
+    temperature=0.01,
+    **kwargs
+)
+```
+
+**Equations**:
+
+![smooth_ap_loss_equation1](imgs/smooth_ap_sigmoid_equation.png){: style="height:100px"}
+![smooth_ap_loss_equation2](imgs/smooth_ap_approx_equation.png){: style="height:100px"}
+![smooth_ap_loss_equation3](imgs/smooth_ap_loss_equation.png){: style="height:100px"}
+
+
+**Parameters**:
+
+* **temperature**: The desired temperature for scaling the sigmoid function. This is denoted by $\tau$ in the first and second equations.
+
+
+**Other info**:
+
+* The loss requires the same number of elements for each class in the batch labels. An example of valid labels is `[1, 1, 2, 2, 3, 3]`. An example of invalid labels is `[1, 1, 1, 2, 2, 3, 3]`, because class `1` has `3` elements while classes `2` and `3` have `2` each. This can be achieved by using `samplers.MPerClassSampler` and setting its `batch_size` and `m` hyperparameters, as shown in the example usage below.
+
+**Default distance**:
+
+ - [```CosineSimilarity()```](distances.md#cosinesimilarity)
+ - This is the only compatible distance.
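+
+**Example usage**:
+
+A minimal sketch of pairing this loss with `samplers.MPerClassSampler`, which yields an equal number of samples per class in each batch, grouped per class. The `dataset`, `dataset.labels`, and `model` objects below are placeholders for your own data and embedding model:
+
+```python
+from torch.utils.data import DataLoader
+
+from pytorch_metric_learning import losses, samplers
+
+loss_func = losses.SmoothAPLoss(temperature=0.01)
+# 8 classes per batch, each appearing exactly m=4 times, grouped per class
+sampler = samplers.MPerClassSampler(dataset.labels, m=4, batch_size=32)
+loader = DataLoader(dataset, batch_size=32, sampler=sampler)
+
+for data, labels in loader:
+    embeddings = model(data)
+    loss = loss_func(embeddings, labels)
+```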
+
 ## SoftTripleLoss
 [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf){target=_blank}
 ```python
diff --git a/src/pytorch_metric_learning/distances/dot_product_similarity.py b/src/pytorch_metric_learning/distances/dot_product_similarity.py
index 2e0b4b01..74be22f5 100644
--- a/src/pytorch_metric_learning/distances/dot_product_similarity.py
+++ b/src/pytorch_metric_learning/distances/dot_product_similarity.py
@@ -9,7 +9,7 @@ def __init__(self, **kwargs):
         assert self.is_inverted
 
     def compute_mat(self, query_emb, ref_emb):
-        return torch.matmul(query_emb, ref_emb.t())
+        return torch.matmul(query_emb, ref_emb.transpose(-1, -2))
 
     def pairwise_distance(self, query_emb, ref_emb):
         return torch.sum(query_emb * ref_emb, dim=1)
diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py
index ba653cda..96403a3b 100644
--- a/src/pytorch_metric_learning/losses/__init__.py
+++ b/src/pytorch_metric_learning/losses/__init__.py
@@ -30,6 +30,7 @@
 from .ranked_list_loss import RankedListLoss
 from .self_supervised_loss import SelfSupervisedLoss
 from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
+from .smooth_ap import SmoothAPLoss
 from .soft_triple_loss import SoftTripleLoss
 from .sphereface_loss import SphereFaceLoss
 from .subcenter_arcface_loss import SubCenterArcFaceLoss
diff --git a/src/pytorch_metric_learning/losses/smooth_ap.py b/src/pytorch_metric_learning/losses/smooth_ap.py
new file mode 100644
index 00000000..b0e441f3
--- /dev/null
+++ b/src/pytorch_metric_learning/losses/smooth_ap.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn.functional as F
+
+from ..distances import CosineSimilarity
+from ..utils import common_functions as c_f
+from ..utils import loss_and_miner_utils as lmu
+from .base_metric_loss_function import BaseMetricLossFunction
+
+
+class SmoothAPLoss(BaseMetricLossFunction):
+    """
+    Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163
+    """
+
+    def __init__(self, temperature=0.01, **kwargs):
+        super().__init__(**kwargs)
+        c_f.assert_distance_type(self, CosineSimilarity)
+        self.temperature = temperature
+
+    def get_default_distance(self):
+        return CosineSimilarity()
+
+    # Implementation is based on the original repository:
+    # https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
+    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
+        # The loss expects the same number of elements for each class in the batch.
+        # The classes and their order do not matter, but elements of the same class must be grouped together, e.g.
+        #
+        # The following labels are valid:
+        #   [ A,A,A, B,B,B, C,C,C ]
+        # The following labels are NOT valid:
+        #   [ B,B,B, A,A,A,A, C,C,C ]
+        #
+        c_f.labels_required(labels)
+        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
+
+        counts = torch.bincount(labels)
+        nonzero_indices = torch.nonzero(counts, as_tuple=True)[0]
+        nonzero_counts = counts[nonzero_indices]
+        if nonzero_counts.unique().size(0) != 1:
+            raise ValueError(
+                "All classes must have the same number of elements in the labels.\n"
+                "The given labels have the following number of elements: {}.\n"
+                "You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
+                    nonzero_counts.cpu().tolist()
+                )
+            )
+
+        batch_size = embeddings.size(0)
+        num_classes_batch = torch.unique(labels).size(0)  # each class then has batch_size // num_classes_batch elements
+
+        mask = 1.0 - torch.eye(batch_size)
+        mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1)
+
+        sims = self.distance(embeddings)
+
+        sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1)
+        sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1)
+        sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device)
+        sims_ranks = torch.sum(sims_sigm, dim=-1) + 1
+
+        xs = embeddings.view(
+            num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1)
+        )
+        pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch)
+        pos_mask = (
+            pos_mask.unsqueeze(dim=0)
+            .unsqueeze(dim=0)
+            .repeat(num_classes_batch, batch_size // num_classes_batch, 1, 1)
+        )
+
+        # Circumvent the shape check in the forward method
+        xs_norm = self.distance.maybe_normalize(xs, dim=-1)
+        sims_pos = self.distance.compute_mat(xs_norm, xs_norm)
+
+        sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat(
+            1, 1, batch_size // num_classes_batch, 1
+        )
+        sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2)
+
+        sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to(
+            sims_diff.device
+        )
+        sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1
+
+        g = batch_size // num_classes_batch
+        ap = torch.zeros(batch_size).to(embeddings.device)
+        for i in range(num_classes_batch):
+            for j in range(g):
+                pos_rank = sims_pos_ranks[i, j]
+                all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g]
+                ap[i * g + j] = torch.sum(pos_rank / all_rank) / g
+
+        miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
+        loss = (1 - ap) * miner_weights
+
+        return {
+            "ap_loss": {
+                "losses": loss,
+                "indices": c_f.torch_arange_from_size(loss),
+                "reduction_type": "element",
+            }
+        }
diff --git a/tests/losses/test_smooth_ap_loss.py b/tests/losses/test_smooth_ap_loss.py
new file mode 100644
index 00000000..422e9423
--- /dev/null
+++ b/tests/losses/test_smooth_ap_loss.py
@@ -0,0 +1,191 @@
+import unittest
+
+import torch
+import torch.nn.functional as F
+
+from pytorch_metric_learning.losses import SmoothAPLoss
+
+from .. import TEST_DEVICE, TEST_DTYPES
+
+HYPERPARAMETERS = {
+    "temp": 0.01,
+    "batch_size": 60,
+    "num_id": 6,
+    "feat_dims": 256,
+}
+TEST_SEEDS = [42, 1234, 5642, 9999, 3459]
+
+
+# Original implementation of the SmoothAP loss taken from:
+# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py
+def sigmoid(tensor, temp=1.0):
+    """temperature controlled sigmoid
+
+    takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
+    """
+    exponent = -tensor / temp
+    # clamp the input tensor for stability
+    exponent = torch.clamp(exponent, min=-50, max=50)
+    y = 1.0 / (1.0 + torch.exp(exponent))
+    return y
+
+
+def compute_aff(x):
+    """computes the affinity matrix between an input vector and itself"""
+    return torch.mm(x, x.t())
+
+
+class SmoothAP(torch.nn.Module):
+    """PyTorch implementation of the Smooth-AP loss.
+
+    implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
+    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
+    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
+
+    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
+
+        labels = ( A, A, A, B, B, B, C, C, C)
+
+    (the order of the classes however does not matter)
+
+    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
+    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
+    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
+
+    Args:
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
+            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
+        batch_size : int
+            the batch size being used during training.
+        num_id : int
+            the number of different classes that are represented in the batch.
+        feat_dims : int
+            the dimension of the input feature embeddings
+
+    Shape:
+        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
+        - Output: scalar
+
+    Examples::
+
+        >>> loss = SmoothAP(0.01, 60, 6, 256)
+        >>> input = torch.randn(60, 256, requires_grad=True).to("cuda:0")
+        >>> output = loss(input)
+        >>> output.backward()
+    """
+
+    def __init__(self, anneal, batch_size, num_id, feat_dims):
+        """
+        Parameters
+        ----------
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function
+        batch_size : int
+            the batch size being used
+        num_id : int
+            the number of different classes that are represented in the batch
+        feat_dims : int
+            the dimension of the input feature embeddings
+        """
+        super(SmoothAP, self).__init__()
+
+        assert batch_size % num_id == 0
+
+        self.anneal = anneal
+        self.batch_size = batch_size
+        self.num_id = num_id
+        self.feat_dims = feat_dims
+
+    def forward(self, preds):
+        """Forward pass for all input predictions: preds - (batch_size x feat_dims)"""
+
+        # ------ differentiable ranking of all retrieval set ------
+        # compute the mask which ignores the relevance score of the query to itself
+        mask = 1.0 - torch.eye(self.batch_size)
+        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
+        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
+        sim_all = compute_aff(preds)
+        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
+        # compute the difference matrix
+        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
+        # pass through the sigmoid
+        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.to(TEST_DEVICE)
+        # compute the rankings
+        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
+
+        # ------ differentiable ranking of only positive set in retrieval set ------
+        # compute the mask which only gives non-zero weights to the positive set
+        xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
+        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
+        pos_mask = (
+            pos_mask.unsqueeze(dim=0)
+            .unsqueeze(dim=0)
+            .repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
+        )
+
+        # compute the relevance scores
+        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
+        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(
+            1, 1, int(self.batch_size / self.num_id), 1
+        )
+        # compute the difference matrix
+        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
+        # pass through the sigmoid
+        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.to(TEST_DEVICE)
+        # compute the rankings of the positive set
+        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1
+
+        # sum the values of the Smooth-AP for all instances in the mini-batch
+        ap = torch.zeros(1).to(TEST_DEVICE)
+        group = int(self.batch_size / self.num_id)
+        for ind in range(self.num_id):
+            pos_divide = torch.sum(
+                sim_pos_rk[ind]
+                / (
+                    sim_all_rk[
+                        (ind * group) : ((ind + 1) * group),
+                        (ind * group) : ((ind + 1) * group),
+                    ]
+                )
+            )
+            ap = ap + ((pos_divide / group) / self.batch_size)
+
+        return 1 - ap
+
+
+class TestSmoothAPLoss(unittest.TestCase):
+    def test_smooth_ap_loss(self):
+        for dtype in TEST_DTYPES:
+            for seed in TEST_SEEDS:
+                torch.manual_seed(seed)
+                loss = SmoothAP(
+                    HYPERPARAMETERS["temp"],
+                    HYPERPARAMETERS["batch_size"],
+                    HYPERPARAMETERS["num_id"],
+                    HYPERPARAMETERS["feat_dims"],
+                )
+                rand_tensor = (
+                    torch.randn(
+                        HYPERPARAMETERS["batch_size"],
+                        HYPERPARAMETERS["feat_dims"],
+                        requires_grad=True,
+                    )
+                    .to(TEST_DEVICE)
+                    .to(dtype)
+                )
+                # The original code uses a model that normalizes the output vector
+                input_ = F.normalize(rand_tensor, p=2.0, dim=-1)
+                output = loss(input_)
+
+                loss2 = SmoothAPLoss(temperature=HYPERPARAMETERS["temp"])
+                # Labels must be ordered class-by-class to match the batch layout the original code assumes
+                labels = []
+                for i in range(HYPERPARAMETERS["num_id"]):
+                    labels.extend(
+                        [i] * (HYPERPARAMETERS["batch_size"] // HYPERPARAMETERS["num_id"])
+                    )
+
+                labels = torch.tensor(labels)
+                output2 = loss2(rand_tensor, labels)
+                self.assertTrue(torch.isclose(output, output2))