diff --git a/.github/workflows/base_test_workflow.yml b/.github/workflows/base_test_workflow.yml index ff965afd..52fd4429 100644 --- a/.github/workflows/base_test_workflow.yml +++ b/.github/workflows/base_test_workflow.yml @@ -13,15 +13,15 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] - pytorch-version: [1.6, 1.11] - torchvision-version: [0.7.0, 0.12.0] - with-collect-stats: [false] - exclude: - - pytorch-version: 1.6 - torchvision-version: 0.12.0 - - pytorch-version: 1.11 - torchvision-version: 0.7.0 + include: + - python-version: 3.8 + pytorch-version: 1.6 + torchvision-version: 0.7 + with-collect-stats: false + - python-version: 3.9 + pytorch-version: 2.1 + torchvision-version: 0.16 + with-collect-stats: false steps: - uses: actions/checkout@v2 @@ -34,6 +34,8 @@ jobs: pip install .[with-hooks-cpu] pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall pip install --upgrade protobuf==3.20.1 + pip install six + pip install packaging - name: Run unit tests run: | TEST_DTYPES=float32,float64 TEST_DEVICE=cpu WITH_COLLECT_STATS=${{ matrix.with-collect-stats }} python -m unittest discover -t . 
-s tests/${{ inputs.module-to-test }} diff --git a/docs/losses.md b/docs/losses.md index 85509cf9..ca7e1c67 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -345,6 +345,19 @@ The queue can be cleared like this: loss_fn.reset_queue() ``` +## DynamicSoftMarginLoss +[Learning Local Descriptors With a CDF-Based Dynamic Soft Margin](https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhang_Learning_Local_Descriptors_With_a_CDF-Based_Dynamic_Soft_Margin_ICCV_2019_paper.pdf) +```python +losses.DynamicSoftMarginLoss(min_val=-2.0, num_bins=10, momentum=0.01, **kwargs) +``` + +**Parameters**: + +* **min_val**: minimum significant value for `d_pos - d_neg` +* **num_bins**: number of equally spaced bins for the partition of the interval `[min_val, ∞)` +* **momentum**: weight assigned to the histogram computed from the current batch + + ## FastAPLoss [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/papers/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.pdf){target=_blank} @@ -969,6 +982,20 @@ loss_optimizer.step() * **loss**: The loss per element in the batch, that results in a non zero exponent in the cross entropy expression. Reduction type is ```"element"```. +## RankedListLoss +[Ranked List Loss for Deep Metric Learning](https://arxiv.org/abs/1903.03238) +```python +losses.RankedListLoss(margin, Tn, imbalance=0.5, alpha=None, Tp=0, **kwargs) +``` + +**Parameters**: + +* **margin** (float): margin between the positive and negative sets +* **imbalance** (float): tradeoff between positive and negative sets. As the name suggests this takes into account + the imbalance between positive and negative samples in the dataset +* **alpha** (float): smallest distance between negative points +* **Tp & Tn** (float): temperatures for weighting positive and negative pairs, respectively. 
+ ## SelfSupervisedLoss diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index 55e47090..3d67cd6b 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.3.0" +__version__ = "2.4.0" diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index a0ba7407..ba653cda 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -6,6 +6,7 @@ from .contrastive_loss import ContrastiveLoss from .cosface_loss import CosFaceLoss from .cross_batch_memory import CrossBatchMemory +from .dynamic_soft_margin_loss import DynamicSoftMarginLoss from .fast_ap_loss import FastAPLoss from .generic_pair_loss import GenericPairLoss from .histogram_loss import HistogramLoss @@ -26,6 +27,7 @@ from .pnp_loss import PNPLoss from .proxy_anchor_loss import ProxyAnchorLoss from .proxy_losses import ProxyNCALoss +from .ranked_list_loss import RankedListLoss from .self_supervised_loss import SelfSupervisedLoss from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss from .soft_triple_loss import SoftTripleLoss diff --git a/src/pytorch_metric_learning/losses/dynamic_soft_margin_loss.py b/src/pytorch_metric_learning/losses/dynamic_soft_margin_loss.py new file mode 100644 index 00000000..c1f30d3e --- /dev/null +++ b/src/pytorch_metric_learning/losses/dynamic_soft_margin_loss.py @@ -0,0 +1,125 @@ +import numpy as np +import torch + +from ..distances import LpDistance +from ..utils import common_functions as c_f +from ..utils import loss_and_miner_utils as lmu +from .base_metric_loss_function import BaseMetricLossFunction + + +def find_hard_negatives(dmat): + """ + a = A * P' + A: N * ndim + P: N * ndim + + a1p1 a1p2 a1p3 a1p4 ... + a2p1 a2p2 a2p3 a2p4 ... + a3p1 a3p2 a3p3 a3p4 ... + a4p1 a4p2 a4p3 a4p4 ... + ... ... ... ... 
+ """ + + pos = dmat.diag() + dmat.fill_diagonal_(np.inf) + + min_a, _ = torch.min(dmat, dim=0) + min_p, _ = torch.min(dmat, dim=1) + neg = torch.min(min_a, min_p) + return pos, neg + + +class DynamicSoftMarginLoss(BaseMetricLossFunction): + r"""Loss function with dynamical margin parameter introduced in https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhang_Learning_Local_Descriptors_With_a_CDF-Based_Dynamic_Soft_Margin_ICCV_2019_paper.pdf + + Args: + min_val: minimum significative value for `d_pos - d_neg` + num_bins: number of equally spaced bins for the partition of the interval [min_val, :math:`+\infty`] + momentum: weight assigned to the histogram computed from the current batch + """ + + def __init__(self, min_val=-2.0, num_bins=10, momentum=0.01, **kwargs): + super().__init__(**kwargs) + c_f.assert_distance_type(self, LpDistance, normalize_embeddings=True, p=2) + self.min_val = min_val + self.num_bins = int(num_bins) + self.delta = 2 * abs(min_val) / num_bins + self.momentum = momentum + self.hist_ = torch.zeros((num_bins,)) + self.add_to_recordable_attributes(list_of_names=["num_bins"], is_stat=False) + + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + self.hist_ = c_f.to_device( + self.hist_, tensor=embeddings, dtype=embeddings.dtype + ) + + if labels is None: + loss = self.compute_loss_without_labels( + embeddings, labels, indices_tuple, ref_emb, ref_labels + ) + else: + loss = self.compute_loss_with_labels( + embeddings, labels, indices_tuple, ref_emb, ref_labels + ) + + if len(loss) == 0: + return self.zero_losses() + + self.update_histogram(loss) + loss = self.weigh_loss(loss) + loss = loss.mean() + return { + "loss": { + "losses": loss, + "indices": None, + "reduction_type": "already_reduced", + } + } + + def compute_loss_without_labels( + self, embeddings, labels, indices_tuple, ref_emb, ref_labels + ): + mat = self.distance(embeddings, ref_emb) + r, c = mat.size() + + d_pos = torch.zeros(max(r, c)) + d_pos = 
c_f.to_device(d_pos, tensor=embeddings, dtype=embeddings.dtype) + d_pos[: min(r, c)] = mat.diag() + mat.fill_diagonal_(np.inf) + + min_a, min_p = torch.zeros(max(r, c)), torch.zeros( + max(r, c) + ) # Check for unequal number of anchors and positives + min_a = c_f.to_device(min_a, tensor=embeddings, dtype=embeddings.dtype) + min_p = c_f.to_device(min_p, tensor=embeddings, dtype=embeddings.dtype) + min_a[:c], _ = torch.min(mat, dim=0) + min_p[:r], _ = torch.min(mat, dim=1) + + d_neg = torch.min(min_a, min_p) + return d_pos - d_neg + + def compute_loss_with_labels( + self, embeddings, labels, indices_tuple, ref_emb, ref_labels + ): + anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets( + indices_tuple, labels, ref_labels, t_per_anchor="all" + ) # Use all instead of t_per_anchor=1 to be deterministic + mat = self.distance(embeddings, ref_emb) + d_pos, d_neg = mat[anchor_idx, positive_idx], mat[anchor_idx, negative_idx] + return d_pos - d_neg + + def update_histogram(self, data): + idx, alpha = torch.floor((data - self.min_val) / self.delta).to( + dtype=torch.long + ), torch.frac((data - self.min_val) / self.delta) + momentum = self.momentum if self.hist_.sum() != 0 else 1.0 + self.hist_ = torch.scatter_add( + (1.0 - momentum) * self.hist_, 0, idx, momentum * (1 - alpha) + ) + self.hist_ = torch.scatter_add(self.hist_, 0, idx + 1, momentum * alpha) + self.hist_ /= self.hist_.sum() + + def weigh_loss(self, data): + CDF = torch.cumsum(self.hist_, 0) + idx = torch.floor((data - self.min_val) / self.delta).to(dtype=torch.long) + return CDF[idx] * data diff --git a/src/pytorch_metric_learning/losses/histogram_loss.py b/src/pytorch_metric_learning/losses/histogram_loss.py index 44899fcb..8916b780 100644 --- a/src/pytorch_metric_learning/losses/histogram_loss.py +++ b/src/pytorch_metric_learning/losses/histogram_loss.py @@ -25,7 +25,7 @@ def __init__(self, n_bins: int = None, delta: float = None, **kwargs): n_bins = 100 self.delta = delta if delta is not None 
else 2 / n_bins - self.add_to_recordable_attributes(name="delta", is_stat=True) + self.add_to_recordable_attributes(name="delta", is_stat=False) def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): c_f.labels_or_indices_tuple_required(labels, indices_tuple) diff --git a/src/pytorch_metric_learning/losses/pnp_loss.py b/src/pytorch_metric_learning/losses/pnp_loss.py index 50996e9f..107244ad 100644 --- a/src/pytorch_metric_learning/losses/pnp_loss.py +++ b/src/pytorch_metric_learning/losses/pnp_loss.py @@ -68,8 +68,8 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): else: raise Exception(f"variant <{self.variant}> not available!") - loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1) - loss = torch.sum(loss) / N + loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1) + loss = torch.sum(loss) / torch.sum(safe_N) if self.variant == "Dq": loss = 1 - loss diff --git a/src/pytorch_metric_learning/losses/ranked_list_loss.py b/src/pytorch_metric_learning/losses/ranked_list_loss.py new file mode 100644 index 00000000..d3e1d745 --- /dev/null +++ b/src/pytorch_metric_learning/losses/ranked_list_loss.py @@ -0,0 +1,96 @@ +import warnings + +import torch + +from ..distances import LpDistance +from ..utils import common_functions as c_f +from .base_metric_loss_function import BaseMetricLossFunction + + +class RankedListLoss(BaseMetricLossFunction): + r"""Ranked List Loss described in https://arxiv.org/abs/1903.03238 + Default parameters correspond to RLL-Simpler, preferred for exploratory analysis. + + Args: + * margin (float): margin between positive and negative set + * imbalance (float): tradeoff between positive and negative sets. 
As the name suggests this takes into account + the imbalance between positive and negative samples in the dataset + * alpha (float): smallest distance between negative points + * Tp & Tn (float): temperatures for, respectively, positive and negative pairs weighting. + """ + + def __init__(self, margin, Tn, imbalance=0.5, alpha=None, Tp=0, **kwargs): + super().__init__(**kwargs) + + self.margin = margin + + assert 0 <= imbalance <= 1, "Imbalance must be between 0 and 1" + self.imbalance = imbalance + + if alpha is not None: + self.alpha = alpha + else: + self.alpha = 1 + margin / 2 + + if Tp > 5 or Tn > 5: + warnings.warn( + "Values of Tp or Tn are too high. Too large temperature values may lead to overflow." + ) + + self.Tp = Tp + self.Tn = Tn + self.add_to_recordable_attributes( + list_of_names=["imbalance", "alpha", "margin", "Tp", "Tn"], is_stat=False + ) + + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + c_f.labels_required(labels) + c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels) + c_f.indices_tuple_not_supported(indices_tuple) + + mat = self.distance(embeddings, embeddings) + # mat.fill_diagonal_(0) + mat = mat - mat * torch.eye(len(mat), device=embeddings.device) + mat = c_f.to_device(mat, device=embeddings.device, dtype=embeddings.dtype) + y = labels.unsqueeze(1) == labels.unsqueeze(0) + + P_star = torch.zeros_like(mat) + N_star = torch.zeros_like(mat) + w_p = torch.zeros_like(mat) + w_n = torch.zeros_like(mat) + + N_star[(~y) * (mat < self.alpha)] = mat[(~y) * (mat < self.alpha)] + y.fill_diagonal_(False) + P_star[y * (mat > (self.alpha - self.margin))] = mat[ + y * (mat > (self.alpha - self.margin)) + ] + + w_p[P_star > 0] = torch.exp( + self.Tp * (P_star[P_star > 0] - (self.alpha - self.margin)) + ) + w_n[N_star > 0] = torch.exp(self.Tn * (self.alpha - N_star[N_star > 0])) + + loss_P = torch.sum( + w_p * (P_star - (self.alpha - self.margin)), dim=1 + ) / torch.sum(w_p + 1e-5, dim=1) + + loss_N = 
torch.sum(w_n * (self.alpha - N_star), dim=1) / torch.sum( + w_n + 1e-5, dim=1 + ) + + # with torch.no_grad(): + # loss_P[loss_P.isnan()] = 0 + # loss_N[loss_N.isnan()] = 0 + + loss_RLL = (1 - self.imbalance) * loss_P + self.imbalance * loss_N + + return { + "loss": { + "losses": loss_RLL, + "indices": c_f.torch_arange_from_size(loss_RLL), + "reduction_type": "element", + } + } + + def get_default_distance(self): + return LpDistance() diff --git a/tests/losses/test_dynamic_soft_margin_loss.py b/tests/losses/test_dynamic_soft_margin_loss.py new file mode 100644 index 00000000..9eb5853b --- /dev/null +++ b/tests/losses/test_dynamic_soft_margin_loss.py @@ -0,0 +1,392 @@ +import torch +import torch.nn as nn + +from pytorch_metric_learning.utils import common_functions as c_f +from pytorch_metric_learning.utils import loss_and_miner_utils as lmu + + +###################################### +#######ORIGINAL IMPLEMENTATION######## +###################################### +# DIRECTLY COPIED FROM https://github.com/lg-zhang/dynamic-soft-margin-pytorch/blob/master/modules/dynamic_soft_margin.py. +# This code is copied from the official implementation +# so that we can make sure our implementation returns the same result. +# Some minor changes were made to avoid errors during testing. +# Every change in the original code is reported and explained. +def compute_distance_matrix_unit_l2(a, b, eps=1e-6): + """ + computes pairwise Euclidean distance and return a N x N matrix + """ + + dmat = torch.matmul(a, torch.transpose(b, 0, 1)) + dmat = ((1.0 - dmat + eps) * 2.0).pow(0.5) + return dmat + + +def find_hard_negatives(dmat, output_index=True, empirical_thresh=0.0): + """ + a = A * P' + A: N * ndim + P: N * ndim + + a1p1 a1p2 a1p3 a1p4 ... + a2p1 a2p2 a2p3 a2p4 ... + a3p1 a3p2 a3p3 a3p4 ... + a4p1 a4p2 a4p3 a4p4 ... + ... ... ... ... 
+ """ + + r, c = dmat.size() # Correct bug + + if not output_index: + pos = torch.zeros(max(r, c)) # Correct bug + pos[: min(r, c)] = dmat.diag() # Correct bug + + dmat = ( + dmat + torch.eye(r, c).to(dmat.device) * 99999 + ) # filter diagonal # Correct bug + dmat[dmat < empirical_thresh] = 99999 # filter outliers in brown dataset + + # Add following 3 lines to solve a bug + min_a, min_p = torch.zeros(max(r, c)), torch.zeros( + max(r, c) + ) # Check for unequal number of anchors and positives + min_a[:c], _ = torch.min(dmat, dim=0) + min_p[:r], _ = torch.min(dmat, dim=1) + + if not output_index: + neg = torch.min(min_a, min_p) + return pos, neg.to(dtype=pos.dtype) # Added cast to avoid errors + + # Useless for our testing purposes + # mask = min_a < min_p + # a_idx = torch.cat( + # (mask.nonzero().view(-1) + cnt, (~mask).nonzero().view(-1)) + # ) # use p as anchor + # p_idx = torch.cat( + # (mask.nonzero().view(-1), (~mask).nonzero().view(-1) + cnt) + # ) # use a as anchor + # n_idx = torch.cat((min_a_idx[mask], min_p_idx[~mask] + cnt)) + # return a_idx, p_idx, n_idx + + +class OriginalImplementationDynamicSoftMarginLoss(nn.Module): + def __init__(self, is_binary=False, momentum=0.01, max_dist=None, nbins=512): + """ + is_binary: true if learning binary descriptor + momentum: weight assigned to the histogram computed from the current batch + max_dist: maximum possible distance in the feature space + nbins: number of bins to discretize the PDF + """ + super(OriginalImplementationDynamicSoftMarginLoss, self).__init__() + self._is_binary = is_binary + + if max_dist is None: + # max_dist = 256 if self._is_binary else 2.0 + max_dist = 2.0 + + self._momentum = momentum + self._max_val = max_dist + self._min_val = -max_dist + self.register_buffer("histogram", torch.ones(nbins)) + + self._stats_initialized = False + self.current_step = None + + def _compute_distances(self, x, labels=None): + # Useless for testing purposes + # if self._is_binary: + # return 
self._compute_hamming_distances(x) + # else: + return self._compute_l2_distances(x, labels=labels) + + # Formatted to test with and without labels + def _compute_l2_distances(self, x, labels=None): + if labels is None: + cnt = x.size(0) // 2 + a = x[:cnt, :] + p = x[cnt:, :] + dmat = compute_distance_matrix_unit_l2(a, p) + return find_hard_negatives(dmat, output_index=False, empirical_thresh=0.008) + else: + dmat = compute_distance_matrix_unit_l2(x, x) + dmat.fill_diagonal_(0) # Put distance to itself to 0 + anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets( + None, labels, labels, t_per_anchor="all" + ) + return dmat[anchor_idx, positive_idx], dmat[anchor_idx, negative_idx] + + # We do not use binary descriptors + # def _compute_hamming_distances(self, x): + # cnt = x.size(0) // 2 + # ndims = x.size(1) + # a = x[:cnt, :] + # p = x[cnt:, :] + + # dmat = compute_distance_matrix_hamming( + # (a > 0).float() * 2.0 - 1.0, (p > 0).float() * 2.0 - 1.0 + # ) + # a_idx, p_idx, n_idx = find_hard_negatives( + # dmat, output_index=True, empirical_thresh=2 + # ) + + # # differentiable Hamming distance + # a = x[a_idx, :] + # p = x[p_idx, :] + # n = x[n_idx, :] + + # pos_dist = (1.0 - a * p).sum(dim=1) / ndims + # neg_dist = (1.0 - a * n).sum(dim=1) / ndims + + # # non-differentiable Hamming distance + # a_b = (a > 0).float() * 2.0 - 1.0 + # p_b = (p > 0).float() * 2.0 - 1.0 + # n_b = (n > 0).float() * 2.0 - 1.0 + + # pos_dist_b = (1.0 - a_b * p_b).sum(dim=1) / ndims + # neg_dist_b = (1.0 - a_b * n_b).sum(dim=1) / ndims + + # return pos_dist, neg_dist, pos_dist_b, neg_dist_b + + def _compute_histogram(self, x, momentum): + """ + update the histogram using the current batch + """ + num_bins = self.histogram.size(0) + x_detached = x.detach() + self.bin_width = (self._max_val - self._min_val) / num_bins # Adjusted formula + lo = torch.floor( + (x_detached - self._min_val) / self.bin_width + ).long() # Add cast to avoid errors + hi = (lo + 1).clamp(min=0, 
max=num_bins - 1) + hist = x.new_zeros(num_bins) + alpha = ( + 1.0 + - (x_detached - self._min_val - lo.float() * self.bin_width) + / self.bin_width + ).to( + dtype=hist.dtype + ) # Added cast to avoid errors + hist.index_add_(0, lo, alpha) + hist.index_add_(0, hi, 1.0 - alpha) + hist = hist / (hist.sum() + 1e-6) + self.histogram = c_f.to_device( + self.histogram, tensor=hist, dtype=hist.dtype + ) # Line added to avoid errors + self.histogram = (1.0 - momentum) * self.histogram + momentum * hist + + def _compute_stats(self, pos_dist, neg_dist): + hist_val = pos_dist - neg_dist + if self._stats_initialized: + self._compute_histogram(hist_val, self._momentum) + else: + self._compute_histogram(hist_val, 1.0) + self._stats_initialized = True + + def forward(self, x, labels=None): + distances = self._compute_distances(x, labels=labels) + if not self._is_binary: + pos_dist, neg_dist = distances + self._compute_stats(pos_dist, neg_dist) + hist_var = pos_dist - neg_dist + else: + pos_dist, neg_dist, pos_dist_b, neg_dist_b = distances + self._compute_stats(pos_dist_b, neg_dist_b) + hist_var = pos_dist_b - neg_dist_b + + PDF = self.histogram / self.histogram.sum() + CDF = PDF.cumsum(0) + + # lookup weight from the CDF + bin_idx = torch.floor((hist_var - self._min_val) / self.bin_width).long() + weight = CDF[bin_idx] + + # Changed to an equivalent version for making same computation as in dynamic_soft_margin_loss.py + # loss = -(neg_dist * weight).mean() + (pos_dist * weight).mean() + loss = (hist_var * weight).mean() + return loss.to(device=x.device, dtype=x.dtype) # Added cast to avoid errors + + +import unittest + +from pytorch_metric_learning.losses import DynamicSoftMarginLoss + +from .. 
import TEST_DEVICE, TEST_DTYPES +from ..zzz_testing_utils.testing_utils import angle_to_coord + + +class TestDynamicSoftMarginLoss(unittest.TestCase): + def test_dynamic_soft_margin_loss_without_labels(self): + torch.manual_seed(21) + for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue + embeddings = torch.randn( + 5, + 32, + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + embeddings = torch.nn.functional.normalize(embeddings) + cnt = embeddings.size(0) // 2 + + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-3.0, + num_bins=10, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-3.0, + num_bins=20, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-3.0, + num_bins=30, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-2.0, + num_bins=10, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-2.0, + num_bins=20, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-2.0, + num_bins=30, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-1.0, + num_bins=10, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-1.0, + num_bins=20, + ) + self.helper( + embeddings[:cnt, :], + None, + dtype, + ref_emb=embeddings[cnt:, :], + min_val=-1.0, + num_bins=30, + ) + + def test_dynamic_soft_margin_loss_with_labels(self): + torch.manual_seed(21) + for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue + embeddings = torch.randn( + 5, + 32, + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + embeddings = torch.nn.functional.normalize(embeddings) + labels = torch.LongTensor([0, 0, 1, 1, 2]) + + self.helper(embeddings, 
labels, dtype, min_val=-3.0, num_bins=10) + self.helper(embeddings, labels, dtype, min_val=-3.0, num_bins=20) + self.helper(embeddings, labels, dtype, min_val=-3.0, num_bins=30) + self.helper(embeddings, labels, dtype, min_val=-2.0, num_bins=10) + self.helper(embeddings, labels, dtype, min_val=-2.0, num_bins=20) + self.helper(embeddings, labels, dtype, min_val=-2.0, num_bins=30) + self.helper(embeddings, labels, dtype, min_val=-1.0, num_bins=10) + self.helper(embeddings, labels, dtype, min_val=-1.0, num_bins=20) + self.helper(embeddings, labels, dtype, min_val=-1.0, num_bins=30) + + def helper( + self, + embeddings, + labels, + dtype, + ref_emb=None, + ref_labels=None, + min_val=-2.0, + num_bins=10, + ): + loss_func = DynamicSoftMarginLoss(min_val=min_val, num_bins=num_bins) + original_loss_func = OriginalImplementationDynamicSoftMarginLoss( + max_dist=-min_val, nbins=num_bins + ) + + loss = loss_func(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels) + if labels is None: + embeddings = torch.cat((embeddings, ref_emb)) + correct_loss = original_loss_func(embeddings, labels) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + + def test_with_no_valid_triplets(self): + torch.manual_seed(21) + loss_func = DynamicSoftMarginLoss() + for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue + embedding_angles = [0, 20, 40, 60, 80] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([0, 1, 2, 3, 4]) + loss = loss_func(embeddings, labels) + self.assertEqual(loss, 0) + + def test_backward(self): + torch.manual_seed(21) + for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue + loss_func = DynamicSoftMarginLoss() + embedding_angles = [0, 20, 40, 60, 80] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, 
+ dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([0, 0, 1, 1, 2]) + + loss = loss_func(embeddings, labels) + loss.backward() diff --git a/tests/losses/test_histogram_loss.py b/tests/losses/test_histogram_loss.py index aeeb54af..5658a397 100644 --- a/tests/losses/test_histogram_loss.py +++ b/tests/losses/test_histogram_loss.py @@ -106,7 +106,7 @@ def histogram(inds, size): ) histogram_pos_inds = torch.tril( torch.ones(histogram_pos_repeat.size()), -1 - ).byte() + ).bool() if self.cuda: histogram_pos_inds = histogram_pos_inds.cuda() histogram_pos_repeat[histogram_pos_inds] = 0 diff --git a/tests/losses/test_manifold_loss.py b/tests/losses/test_manifold_loss.py index 8d786a58..6dd8792c 100644 --- a/tests/losses/test_manifold_loss.py +++ b/tests/losses/test_manifold_loss.py @@ -151,7 +151,10 @@ def loss_incorrect_descriptors_dim(): class TestManifoldLoss(unittest.TestCase): def test_intrinsic_and_context_losses(self): + torch.manual_seed(24) for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue batch_size, embedding_size = 32, 128 n_proxies = 3 @@ -191,7 +194,10 @@ def test_intrinsic_and_context_losses(self): self.assertTrue(torch.isclose(original_loss, loss, rtol=rtol)) def test_with_original_implementation(self): + torch.manual_seed(24) for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue batch_size, embedding_size = 32, 128 n_proxies = 5 diff --git a/tests/losses/test_p2s_grad_loss.py b/tests/losses/test_p2s_grad_loss.py index 991f0354..7de8ab04 100644 --- a/tests/losses/test_p2s_grad_loss.py +++ b/tests/losses/test_p2s_grad_loss.py @@ -149,10 +149,13 @@ def forward(self, input_score, target): class TestP2SGradLoss(unittest.TestCase): def test_p2s_grad_loss_with_paper_formula(self): + torch.manual_seed(23) num_classes = 20 batch_size = 100 descriptors_dim = 128 for dtype in TEST_DTYPES: + if dtype == torch.float16: + continue embeddings = torch.randn( batch_size, descriptors_dim, @@ -196,6 +199,7 @@ def 
test_p2s_grad_loss_with_paper_formula(self): ) def test_p2s_grad_loss_with_trusted_implementation(self): + torch.manual_seed(23) num_classes = 20 batch_size = 100 descriptors_dim = 128 diff --git a/tests/losses/test_pnp_loss.py b/tests/losses/test_pnp_loss.py index 401cb6dc..6a1ad3ea 100644 --- a/tests/losses/test_pnp_loss.py +++ b/tests/losses/test_pnp_loss.py @@ -130,3 +130,11 @@ def test_pnp_loss(self): with self.assertRaises(ValueError): PNPLoss(b, alpha, anneal, "PNP") + + def test_negatives_that_have_no_positives(self): + loss_func = PNPLoss() + labels = torch.tensor([1, 1, 2], device=TEST_DEVICE) + for dtype in TEST_DTYPES: + embeddings = torch.randn(3, 32, dtype=dtype, device=TEST_DEVICE) + loss = loss_func(embeddings, labels) + self.assertTrue(not torch.isnan(loss)) diff --git a/tests/losses/test_ranked_list_loss.py b/tests/losses/test_ranked_list_loss.py new file mode 100644 index 00000000..66dc9492 --- /dev/null +++ b/tests/losses/test_ranked_list_loss.py @@ -0,0 +1,168 @@ +import unittest + +import torch + +from pytorch_metric_learning.losses import RankedListLoss + +from .. 
import TEST_DEVICE, TEST_DTYPES + + +class TestRankedListLoss(unittest.TestCase): + def test_ranked_list_loss_simpler(self): + torch.manual_seed(22) + batch_size = 32 + embedding_size = 64 + for dtype in TEST_DTYPES: + # test multiple times + for _ in range(2): + embeddings = torch.randn( + batch_size, + embedding_size, + requires_grad=True, + dtype=dtype, + ).to(TEST_DEVICE) + labels = torch.randint(0, 5, size=(batch_size,)) + + normalized_embeddings = torch.nn.functional.normalize( + embeddings, p=2, dim=1 + ) + n = len(embeddings) + for Tp, lam, margin in zip( + [0, 0.5, 3, -10], [0, 0.5, 0.7, 0.9], [0, 0.4, 0.8, 1.2] + ): + alpha = 1 - margin / 2 + Tn = Tp + loss_func = RankedListLoss( + margin=margin, Tn=Tn, imbalance=lam, alpha=alpha, Tp=Tp + ) + + L_RLL = torch.zeros( + n, + ).to(dtype=dtype) + for i in range(n): + w_p = torch.zeros( + n, + ).to(dtype=dtype) + w_n = torch.zeros( + n, + ).to(dtype=dtype) + L_P = torch.zeros( + n, + ).to(dtype=dtype) + L_N = torch.zeros( + n, + ).to(dtype=dtype) + for j in range(n): + if i == j: + continue + + d_ij = ( + torch.sum( + ( + normalized_embeddings[i, :] + - normalized_embeddings[j, :] + ) + ** 2 + ) + ** 0.5 + ) + if labels[j] == labels[i] and d_ij > alpha - margin: + w_p[j] = torch.exp(Tp * (d_ij - (alpha - margin))) + L_P[j] = d_ij - (alpha - margin) + elif labels[j] != labels[i] and d_ij < alpha: + w_n[j] = torch.exp(Tn * (alpha - d_ij)) + L_N[j] = alpha - d_ij + L_P = torch.sum(w_p * L_P) / torch.sum(w_p + 1e-5) + L_N = torch.sum(w_n * L_N) / torch.sum(w_n + 1e-5) + + L_RLL[i] = (1 - lam) * L_P + lam * L_N + correct_loss = torch.mean(L_RLL) + loss = loss_func(embeddings, labels) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + + loss.backward() + + def test_ranked_list_loss(self): + torch.manual_seed(22) + batch_size = 32 + embedding_size = 64 + for dtype in TEST_DTYPES: + # test multiple times + for _ in range(2): + embeddings = torch.randn( 
+ batch_size, + embedding_size, + requires_grad=True, + dtype=dtype, + ).to(TEST_DEVICE) + labels = torch.randint(0, 5, size=(batch_size,)) + + normalized_embeddings = torch.nn.functional.normalize( + embeddings, p=2, dim=1 + ) + n = len(embeddings) + for Tn, Tp, alpha, lam, margin in zip( + [0.3, 0.8, 2, -10], + [0, 0.5, 3, -10], + [0.3, 0.8, 1, 3], + [0, 0.5, 0.7, 0.9], + [0, 0.4, 0.8, 1.2], + ): + loss_func = RankedListLoss( + margin=margin, Tn=Tn, imbalance=lam, alpha=alpha, Tp=Tp + ) + + L_RLL = torch.zeros( + n, + ).to(dtype=dtype) + for i in range(n): + w_p = torch.zeros( + n, + ).to(dtype=dtype) + w_n = torch.zeros( + n, + ).to(dtype=dtype) + L_P = torch.zeros( + n, + ).to(dtype=dtype) + L_N = torch.zeros( + n, + ).to(dtype=dtype) + for j in range(n): + if i == j: + continue + + d_ij = ( + torch.sum( + ( + normalized_embeddings[i, :] + - normalized_embeddings[j, :] + ) + ** 2 + ) + ** 0.5 + ) + if labels[j] == labels[i] and d_ij > alpha - margin: + w_p[j] = torch.exp(Tp * (d_ij - (alpha - margin))) + L_P[j] = d_ij - (alpha - margin) + elif labels[j] != labels[i] and d_ij < alpha: + w_n[j] = torch.exp(Tn * (alpha - d_ij)) + L_N[j] = alpha - d_ij + L_P = torch.sum(w_p * L_P) / torch.sum(w_p + 1e-5) + L_N = torch.sum(w_n * L_N) / torch.sum(w_n + 1e-5) + + L_RLL[i] = (1 - lam) * L_P + lam * L_N + correct_loss = torch.mean(L_RLL) + loss = loss_func(embeddings, labels) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + + loss.backward() + + def test_assertion_raises(self): + with self.assertRaises(AssertionError): + _ = RankedListLoss(margin=1, Tn=0, imbalance=-1) + _ = RankedListLoss(margin=1, Tn=0, imbalance=2) diff --git a/tests/reducers/test_setting_reducers.py b/tests/reducers/test_setting_reducers.py index 202145b5..8667bae0 100644 --- a/tests/reducers/test_setting_reducers.py +++ b/tests/reducers/test_setting_reducers.py @@ -18,7 +18,7 @@ def test_setting_reducers(self): ]: L = 
loss(reducer=reducer) if isinstance(L, TripletMarginLoss): - assert type(L.reducer) == type(reducer) + assert type(L.reducer) is type(reducer) else: for v in L.reducer.reducers.values(): - assert type(v) == type(reducer) + assert type(v) is type(reducer)