From c75092bf9311648c0d519a790070c1744f54ba77 Mon Sep 17 00:00:00 2001
From: Franck Mamalet
Date: Mon, 15 Apr 2024 08:13:57 +0200
Subject: [PATCH] add support for soft HKR loss: SoftHKRMulticlassLoss
 (warning: alpha in [0,1])

---
 deel/torchlip/__init__.py         |   1 +
 deel/torchlip/modules/__init__.py |   1 +
 deel/torchlip/modules/loss.py     | 137 +++++++++++++++++++++++++++++-
 3 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/deel/torchlip/__init__.py b/deel/torchlip/__init__.py
index 72c0e52..eed670d 100644
--- a/deel/torchlip/__init__.py
+++ b/deel/torchlip/__init__.py
@@ -39,6 +39,7 @@
     "GroupSort2",
     "HKRLoss",
     "HKRMulticlassLoss",
+    "SoftHKRMulticlassLoss",
     "HingeMarginLoss",
     "HingeMulticlassLoss",
     "InvertibleDownSampling",
diff --git a/deel/torchlip/modules/__init__.py b/deel/torchlip/modules/__init__.py
index 7f730d4..5efbb08 100644
--- a/deel/torchlip/modules/__init__.py
+++ b/deel/torchlip/modules/__init__.py
@@ -59,6 +59,7 @@
 from .loss import HingeMulticlassLoss
 from .loss import HKRLoss
 from .loss import HKRMulticlassLoss
+from .loss import SoftHKRMulticlassLoss
 from .loss import KRLoss
 from .loss import KRMulticlassLoss
 from .loss import NegKRLoss
diff --git a/deel/torchlip/modules/loss.py b/deel/torchlip/modules/loss.py
index 312be33..d155e13 100644
--- a/deel/torchlip/modules/loss.py
+++ b/deel/torchlip/modules/loss.py
@@ -27,7 +27,7 @@
+import warnings
 from typing import Tuple
 
 import torch
-
 from .. import functional as F
 
 
@@ -161,3 +160,139 @@ def __init__(
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         return F.hkr_multiclass_loss(input, target, self.alpha, self.min_margin)
+
+
+class SoftHKRMulticlassLoss(torch.nn.Module):
+    def __init__(
+        self,
+        alpha=10.0,
+        min_margin=1.0,
+        alpha_mean=0.99,
+        temperature=1.0,
+    ):
+        """
+        The multiclass version of HKR with a temperature softmax: the HKR
+        term is computed for each class, and the per-class terms are combined
+        with softmax weights rather than a plain average. The loss is
+        returned per sample (no reduction is applied).
+
+        Note that `y_true` can be either one-hot encoded (0/1) or signed
+        (+/-1) values.
+
+        Args:
+            alpha (float): regularization factor (0 <= alpha <= 1),
+                0 for KR only, 1 for hinge only; values above 1 are
+                deprecated and rescaled to alpha / (alpha + 1)
+            min_margin (float): margin to enforce.
+            alpha_mean (float): momentum factor of the running mean of
+                |y_pred| (an exponential moving average) used to rescale
+                the softmax temperature
+            temperature (float): factor for softmax temperature (higher
+                values increase the weight of the highest non-y_true logits)
+        """
+        super(SoftHKRMulticlassLoss, self).__init__()
+        if (alpha >= 0) and (alpha <= 1):
+            self.alpha = torch.tensor(alpha, dtype=torch.float32)
+        else:
+            # alpha > 1 is deprecated: rescale to the equivalent value in [0,1]
+            warnings.warn(
+                f"alpha should be in [0,1]; deprecated value {alpha} "
+                f"is replaced by {alpha / (alpha + 1.0)}",
+                DeprecationWarning,
+            )
+            self.alpha = torch.tensor(alpha / (alpha + 1.0), dtype=torch.float32)
+        self.min_margin_v = min_margin
+        self.alpha_mean = alpha_mean
+
+        # running mean of |y_pred|, initialized at the margin and clamped
+        # to [0.005, 1000] after each update (see clamp_current_mean)
+        self.current_mean = torch.tensor((self.min_margin_v,), dtype=torch.float32)
+
+        self.temperature = temperature * self.min_margin_v
+        if alpha == 1.0:  # alpha = 1.0 => hinge only
+            self.fct = self.multiclass_hinge_soft
+        elif alpha == 0.0:  # alpha = 0.0 => KR only
+            self.fct = self.kr_soft
+        else:
+            self.fct = self.hkr
+
+    def clamp_current_mean(self, x):
+        return torch.clamp(x, 0.005, 1000)
+
+    def _update_mean(self, y_pred):
+        self.current_mean = self.current_mean.to(y_pred.device)
+        current_global_mean = torch.mean(torch.abs(y_pred)).to(
+            dtype=self.current_mean.dtype
+        )
+        current_global_mean = (
+            self.alpha_mean * self.current_mean
+            + (1 - self.alpha_mean) * current_global_mean
+        )
+        # detach the stored mean so the autograd graph of previous batches
+        # is not kept alive across training steps
+        self.current_mean = self.clamp_current_mean(current_global_mean).detach()
+        total_mean = torch.clamp(current_global_mean, self.min_margin_v, 20000)
+        return total_mean
+
+    def computeTemperatureSoftMax(self, y_true, y_pred):
+        total_mean = self._update_mean(y_pred)
+        current_temperature = (
+            torch.clamp(self.temperature / total_mean, 0.005, 250)
+            .to(dtype=y_pred.dtype)
+            .detach()
+        )
+        min_value = torch.tensor(
+            torch.finfo(torch.float32).min, dtype=y_pred.dtype
+        ).to(device=y_pred.device)
+        # mask the true-class logit with the most negative value so the
+        # softmax runs over the other classes only, then force its weight to 1
+        opposite_values = torch.where(
+            y_true > 0, min_value, current_temperature * y_pred
+        )
+        F_soft_KR = torch.softmax(opposite_values, dim=-1)
+        one_value = torch.tensor(1.0, dtype=F_soft_KR.dtype).to(device=y_pred.device)
+        F_soft_KR = torch.where(y_true > 0, one_value, F_soft_KR)
+        return F_soft_KR
+
+    def signed_y_pred(self, y_true, y_pred):
+        """Return for each item sign(y_true) * y_pred."""
+        sign_y_true = torch.where(y_true > 0, 1, -1)  # switch to +/-1
+        sign_y_true = sign_y_true.to(dtype=y_pred.dtype)
+        return y_pred * sign_y_true
+
+    def multiclass_hinge_preproc(self, signed_y_pred, min_margin):
+        """From multiclass_hinge(y_true, y_pred, min_margin),
+        simplified to use the precomputed signed_y_pred."""
+        # compute the elementwise hinge term
+        return torch.nn.functional.relu(min_margin / 2.0 - signed_y_pred)
+
+    def multiclass_hinge_soft_preproc(self, signed_y_pred, F_soft_KR):
+        hinge = self.multiclass_hinge_preproc(signed_y_pred, self.min_margin_v)
+        return torch.sum(hinge * F_soft_KR, dim=-1)
+
+    def multiclass_hinge_soft(self, y_true, y_pred):
+        F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
+        signed_y_pred = self.signed_y_pred(y_true, y_pred)
+        return self.multiclass_hinge_soft_preproc(signed_y_pred, F_soft_KR)
+
+    def kr_soft_preproc(self, signed_y_pred, F_soft_KR):
+        # negative sign: the KR term is maximized while the loss is minimized
+        return torch.sum(-signed_y_pred * F_soft_KR, dim=-1)
+
+    def kr_soft(self, y_true, y_pred):
+        F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
+        signed_y_pred = self.signed_y_pred(y_true, y_pred)
+        return self.kr_soft_preproc(signed_y_pred, F_soft_KR)
+
+    def hkr(self, y_true, y_pred):
+        F_soft_KR = self.computeTemperatureSoftMax(y_true, y_pred)
+        signed_y_pred = self.signed_y_pred(y_true, y_pred)
+        loss_softkr = self.kr_soft_preproc(signed_y_pred, F_soft_KR)
+        loss_softhinge = self.multiclass_hinge_soft_preproc(signed_y_pred, F_soft_KR)
+        return (1 - self.alpha) * loss_softkr + self.alpha * loss_softhinge
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        # accept array-like inputs: the dtype and device handling above
+        # requires true tensors
+        if not isinstance(input, torch.Tensor):
+            input = torch.as_tensor(input)
+        if not isinstance(target, torch.Tensor):
+            target = torch.as_tensor(target, dtype=input.dtype)
+        return self.fct(target, input)
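
Review note (not part of the patch): a minimal usage sketch of the new loss.
Shapes and hyper-parameter values below are illustrative, and the logits are
assumed to come from a 1-Lipschitz torchlip network for the HKR guarantees
to apply. The loss computes (1 - alpha) * KR_soft + alpha * hinge_soft per
sample, without reduction, so reduce it before calling backward; labels may
be one-hot (0/1) or signed (+/-1).

    import torch
    from deel.torchlip import SoftHKRMulticlassLoss

    # illustrative setup: batch of 8 samples, 10 classes
    criterion = SoftHKRMulticlassLoss(alpha=0.95, min_margin=1.0, temperature=1.0)

    logits = torch.randn(8, 10, requires_grad=True)  # stand-in for model outputs
    labels = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), num_classes=10)

    loss = criterion(logits, labels)  # forward(input, target), like the other losses
    loss.mean().backward()            # per-sample values: reduce before backward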