From 115bd9583ec85178ddd265afa3be8ba13d26f83f Mon Sep 17 00:00:00 2001
From: Franck Mamalet
Date: Fri, 11 Oct 2024 12:05:10 +0200
Subject: [PATCH] add support for multi-GPU in binary losses

---
 deel/torchlip/functional.py   | 116 ++++++++++++++++++++++++++++++++++
 deel/torchlip/modules/loss.py |  38 ++++++++---
 tests/utils_framework.py      |   6 +-
 3 files changed, 149 insertions(+), 11 deletions(-)

diff --git a/deel/torchlip/functional.py b/deel/torchlip/functional.py
index e2af036..b505dc6 100644
--- a/deel/torchlip/functional.py
+++ b/deel/torchlip/functional.py
@@ -327,6 +327,30 @@ def kr_loss(
     return torch.mean(weighted_input, dim=-1)
 
 
+def kr_loss_multi_gpu(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    r"""Returns the element-wise KR loss when computing with a multi-GPU/TPU strategy.
+
+    `target` and `input` can be either of shape (batch_size, 1) or
+    (batch_size, # classes).
+
+    When using this loss function, the labels `target` must be pre-processed with the
+    `process_labels_for_multi_gpu()` function.
+
+    Args:
+        input: Tensor of arbitrary shape.
+        target: pre-processed Tensor of the same shape as input.
+
+    Returns:
+        The Wasserstein-1 loss between ``input`` and ``target``.
+    """
+    target = target.view(input.shape).to(input.dtype)
+    # Since the batch size information was included in `target` by
+    # `process_labels_for_multi_gpu()`, there is no need to multiply by batch size here.
+    # In the binary case (`target` of shape (batch_size, 1)), `torch.mean(dim=-1)`
+    # behaves like `torch.squeeze()` and returns an element-wise loss of shape (batch_size,).
+    return torch.mean(input * target, dim=-1)
+
+
 def neg_kr_loss(
     input: torch.Tensor,
     target: torch.Tensor,
@@ -350,6 +374,33 @@ def neg_kr_loss(
     return -kr_loss(input, target)
 
 
+def neg_kr_loss_multi_gpu(
+    input: torch.Tensor,
+    target: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Loss to estimate the negative Wasserstein-1 distance using Kantorovich-Rubinstein
+    duality.
+
+    `target` and `input` can be either of shape (batch_size, 1) or
+    (batch_size, # classes).
+
+    When using this loss function, the labels `target` must be pre-processed with the
+    `process_labels_for_multi_gpu()` function.
+
+    Args:
+        input: Tensor of arbitrary shape.
+        target: pre-processed Tensor of the same shape as input.
+
+    Returns:
+        The negative Wasserstein-1 loss between ``input`` and ``target``.
+
+    See Also:
+        :py:func:`kr_loss_multi_gpu`
+    """
+    return -kr_loss_multi_gpu(input, target)
+
+
 def hinge_margin_loss(
     input: torch.Tensor,
     target: torch.Tensor,
@@ -414,6 +465,42 @@ def hkr_loss(
     )
 
 
+def hkr_loss_multi_gpu(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    alpha: float,
+    min_margin: float = 1.0,
+) -> torch.Tensor:
+    """
+    Loss to estimate the Wasserstein-1 distance with a hinge regularization using
+    Kantorovich-Rubinstein duality.
+
+    Args:
+        input: Tensor of arbitrary shape.
+        target: Tensor of the same shape as input, pre-processed with
+            `process_labels_for_multi_gpu()`.
+        alpha: Regularization factor between the hinge and the KR loss.
+        min_margin: Minimal margin for the hinge loss.
+
+    Returns:
+        The regularized Wasserstein-1 loss.
+
+    See Also:
+        :py:func:`hinge_margin_loss`
+        :py:func:`kr_loss_multi_gpu`
+    """
+    assert alpha <= 1.0
+    if alpha == 1.0:  # alpha == 1.0 means hinge loss only
+        return hinge_margin_loss(input, target, min_margin)
+    if alpha == 0:
+        return -kr_loss_multi_gpu(input, target)
+    # The positive value comes first so that the KR term stays coherent with the
+    # hinge loss (positive y_pred).
+    return alpha * hinge_margin_loss(input, target, min_margin) - (
+        1 - alpha
+    ) * kr_loss_multi_gpu(input, target)
+
+
 def kr_multiclass_loss(
     input: torch.Tensor,
     target: torch.Tensor,
@@ -506,3 +593,32 @@ def hkr_multiclass_loss(
     return alpha * hinge_multiclass_loss(input, target, min_margin) - (
         1 - alpha
     ) * kr_multiclass_loss(input, target)
+
+
+def process_labels_for_multi_gpu(labels: torch.Tensor) -> torch.Tensor:
+    """Process labels to be fed to any loss based on KR estimation with a multi-GPU/TPU
+    strategy.
+
+    When using a multi-GPU/TPU strategy, the flag `multi_gpu` in KR-based losses must be
+    set to True and the labels have to be pre-processed with this function.
+
+    For binary classification, the labels should be of shape [batch_size, 1].
+    For multiclass problems, the labels must be one-hot encoded (1 or 0) with shape
+    [batch_size, number of classes].
+
+    Args:
+        labels (torch.Tensor): tensor containing the labels.
+
+    Returns:
+        torch.Tensor: labels processed for KR-based losses with a multi-GPU/TPU strategy.
+    """
+    pos_labels = torch.where(labels > 0, 1.0, 0.0).to(labels.dtype)
+    mean_pos = torch.mean(pos_labels, dim=0)
+    # pos factor = batch_size / number of positive samples
+    pos_factor = torch.nan_to_num(1.0 / mean_pos)
+    # neg factor = batch_size / number of negative samples
+    neg_factor = -torch.nan_to_num(1.0 / (1.0 - mean_pos))
+
+    # Since the element-wise KR terms are averaged by the loss reduction later on,
+    # the batch size must be folded into these factors here.
+    return torch.where(labels > 0, pos_factor, neg_factor)
diff --git a/deel/torchlip/modules/loss.py b/deel/torchlip/modules/loss.py
index 11775d7..9dad2ef 100644
--- a/deel/torchlip/modules/loss.py
+++ b/deel/torchlip/modules/loss.py
@@ -38,19 +38,26 @@ class KRLoss(torch.nn.Module):
     duality.
     """
 
-    def __init__(self, reduction: str = "mean", true_values=None):
+    def __init__(self, multi_gpu=False, reduction: str = "mean", true_values=None):
         """
         Args:
-            true_values: tuple containing the two label for each predicted class.
+            multi_gpu (bool): set to True when running on a multi-GPU/TPU setup.
+            reduction: type of reduction applied to the output loss.
+            true_values: deprecated, must be None.
         """
         super().__init__()
         self.reduction = reduction
+        self.multi_gpu = multi_gpu
+        if multi_gpu:
+            self.kr_function = F.kr_loss_multi_gpu
+        else:
+            self.kr_function = F.kr_loss
         assert (
             true_values is None
         ), "depreciated true_values should be None (use target>0)"
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        loss_batch = F.kr_loss(input, target)
+        loss_batch = self.kr_function(input, target)
         return F.apply_reduction(loss_batch, self.reduction)
 
 
@@ -60,19 +67,26 @@ class NegKRLoss(torch.nn.Module):
     the Kantorovich-Rubinstein duality.
     """
 
-    def __init__(self, reduction: str = "mean", true_values=None):
+    def __init__(self, multi_gpu=False, reduction: str = "mean", true_values=None):
         """
         Args:
-            true_values: tuple containing the two label for each predicted class.
+            multi_gpu (bool): set to True when running on a multi-GPU/TPU setup.
+            reduction: type of reduction applied to the output loss.
+            true_values: deprecated, must be None.
""" super().__init__() self.reduction = reduction + self.multi_gpu = multi_gpu + if multi_gpu: + self.kr_function = F.neg_kr_loss_multi_gpu + else: + self.kr_function = F.neg_kr_loss assert ( true_values is None ), "depreciated true_values should be None (use target>0)" def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - loss_batch = F.neg_kr_loss(input, target) + loss_batch = self.kr_function(input, target) return F.apply_reduction(loss_batch, self.reduction) @@ -105,6 +119,7 @@ def __init__( self, alpha: float, min_margin: float = 1.0, + multi_gpu=False, reduction: str = "mean", true_values=None, ): @@ -112,10 +127,17 @@ def __init__( Args: alpha: Regularization factor ([0,1]) between the hinge and the KR loss. min_margin: Minimal margin for the hinge loss. - true_values: tuple containing the two label for each predicted class. + multi_gpu (bool): set to True when running on multi-GPU/TPU + reduction: passed to tf.keras.Loss constructor + true_values: depreciated. """ super().__init__() self.reduction = reduction + self.multi_gpu = multi_gpu + if multi_gpu: + self.hkr_function = F.hkr_loss_multi_gpu + else: + self.hkr_function = F.hkr_loss if (alpha >= 0) and (alpha <= 1): self.alpha = alpha else: @@ -130,7 +152,7 @@ def __init__( ), "depreciated true_values should be None (use target>0)" def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - loss_batch = F.hkr_loss(input, target, self.alpha, self.min_margin) + loss_batch = self.hkr_function(input, target, self.alpha, self.min_margin) return F.apply_reduction(loss_batch, self.reduction) diff --git a/tests/utils_framework.py b/tests/utils_framework.py index 543dfa5..7955a1a 100644 --- a/tests/utils_framework.py +++ b/tests/utils_framework.py @@ -65,6 +65,7 @@ from deel.torchlip.modules import vanilla_model from deel.torchlip.functional import invertible_downsample from deel.torchlip.functional import invertible_upsample +from deel.torchlip.functional import process_labels_for_multi_gpu from deel.torchlip.utils.bjorck_norm import bjorck_norm, remove_bjorck_norm from deel.torchlip.utils.frobenius_norm import ( @@ -147,7 +148,6 @@ def __call__(self, **kwargs): TauCategoricalCrossentropyLoss = TauCrossEntropyLoss TauSparseCategoricalCrossentropyLoss = TauCrossEntropyLoss TauBinaryCrossentropyLoss = TauBCEWithLogitsLoss -process_labels_for_multi_gpu = module_Unavailable_class CategoricalProvableRobustAccuracy = module_Unavailable_class BinaryProvableRobustAccuracy = module_Unavailable_class CategoricalProvableAvgRobustness = module_Unavailable_class @@ -224,13 +224,13 @@ def get_instance_withcheck( KRLoss: partial( get_instance_withcheck, dict_keys_replace={"name": None}, - list_keys_notimplemented=["multi_gpu"], + list_keys_notimplemented=[], ), HingeMarginLoss: partial(get_instance_withcheck, dict_keys_replace={"name": None}), HKRLoss: partial( get_instance_withcheck, dict_keys_replace={"name": None}, - list_keys_notimplemented=["multi_gpu"], + list_keys_notimplemented=[], ), HingeMulticlassLoss: partial( get_instance_withcheck,