Commit
add support for multi gpu in binary losses
Franck Mamalet committed Oct 11, 2024
1 parent c24ccf8 commit 115bd95
Showing 3 changed files with 149 additions and 11 deletions.
116 changes: 116 additions & 0 deletions deel/torchlip/functional.py
@@ -327,6 +327,30 @@ def kr_loss(
    return torch.mean(weighted_input, dim=-1)


def kr_loss_multi_gpu(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""Returns the element-wise KR loss when computing with a multi-GPU/TPU strategy.

    `target` and `input` can be either of shape (batch_size, 1) or
    (batch_size, # classes).

    When using this loss function, the labels `target` must be pre-processed with the
    `process_labels_for_multi_gpu()` function.

    Args:
        input: Tensor of arbitrary shape.
        target: pre-processed Tensor of the same shape as input.

    Returns:
        The Wasserstein-1 loss between ``input`` and ``target``.
    """
    target = target.view(input.shape).to(input.dtype)
    # Since the information of batch size was included in `target` by
    # `process_labels_for_multi_gpu()`, there is no need to multiply by batch size here.
    # In the binary case (`target` of shape (batch_size, 1)), `torch.mean(dim=-1)`
    # behaves like `torch.squeeze()` and returns an element-wise loss of shape
    # (batch_size,).
    return torch.mean(input * target, dim=-1)
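
A minimal usage sketch for this function (not part of the diff; tensor values are illustrative, `scores` stands in for model outputs, and `process_labels_for_multi_gpu` is defined further down in this file):

# Hypothetical sketch: binary labels of shape (batch_size, 1).
import torch
from deel.torchlip.functional import kr_loss_multi_gpu, process_labels_for_multi_gpu

labels = torch.tensor([[1.0], [0.0], [1.0], [1.0]])  # 3 positives, 1 negative
target = process_labels_for_multi_gpu(labels)        # per-sample weights: +4/3 and -4
scores = torch.randn(4, 1)                           # stand-in for model outputs
loss = kr_loss_multi_gpu(scores, target)             # element-wise loss, shape (4,)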


def neg_kr_loss(
    input: torch.Tensor,
    target: torch.Tensor,
@@ -350,6 +374,33 @@ def neg_kr_loss(
    return -kr_loss(input, target)


def neg_kr_loss_multi_gpu(
    input: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """
    Loss to estimate the negative Wasserstein-1 distance using Kantorovich-Rubinstein
    duality.

    `target` and `input` can be either of shape (batch_size, 1) or
    (batch_size, # classes).

    When using this loss function, the labels `target` must be pre-processed with the
    `process_labels_for_multi_gpu()` function.

    Args:
        input: Tensor of arbitrary shape.
        target: pre-processed Tensor of the same shape as input.

    Returns:
        The negative Wasserstein-1 loss between ``input`` and ``target``.

    See Also:
        :py:func:`kr_loss`
    """
    return -kr_loss_multi_gpu(input, target)


def hinge_margin_loss(
    input: torch.Tensor,
    target: torch.Tensor,
@@ -414,6 +465,42 @@ def hkr_loss(
    )


def hkr_loss_multi_gpu(
    input: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    min_margin: float = 1.0,
) -> torch.Tensor:
    """
    Loss to estimate the Wasserstein-1 distance with a hinge regularization using
    Kantorovich-Rubinstein duality.

    When using this loss function, the labels `target` must be pre-processed with the
    `process_labels_for_multi_gpu()` function.

    Args:
        input: Tensor of arbitrary shape.
        target: pre-processed Tensor of the same shape as input.
        alpha: Regularization factor between the hinge and the KR loss.
        min_margin: Minimal margin for the hinge loss.

    Returns:
        The regularized Wasserstein-1 loss.

    See Also:
        :py:func:`hinge_margin_loss`
        :py:func:`kr_loss`
    """
    assert alpha <= 1.0
    if alpha == 1.0:  # alpha == 1.0 means hinge only
        return hinge_margin_loss(input, target, min_margin)
    if alpha == 0:
        return -kr_loss_multi_gpu(input, target)
    # Positive values must come first to stay coherent with the hinge loss
    # (positive y_pred).
    return alpha * hinge_margin_loss(input, target, min_margin) - (
        1 - alpha
    ) * kr_loss_multi_gpu(input, target)
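
A small consistency sketch (hypothetical values, not part of the diff) showing that the blended branch equals the explicit combination of the hinge and KR terms:

import torch
from deel.torchlip import functional as F

labels = torch.tensor([[1.0], [0.0]])
target = F.process_labels_for_multi_gpu(labels)  # weights +2 / -2 for this batch
scores = torch.tensor([[0.7], [-0.3]])
hkr = F.hkr_loss_multi_gpu(scores, target, alpha=0.9, min_margin=1.0)
manual = 0.9 * F.hinge_margin_loss(scores, target, 1.0) \
    - (1 - 0.9) * F.kr_loss_multi_gpu(scores, target)
assert torch.allclose(hkr, manual)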


def kr_multiclass_loss(
    input: torch.Tensor,
    target: torch.Tensor,
@@ -506,3 +593,32 @@ def hkr_multiclass_loss(
    return alpha * hinge_multiclass_loss(input, target, min_margin) - (
        1 - alpha
    ) * kr_multiclass_loss(input, target)


def process_labels_for_multi_gpu(labels: torch.Tensor) -> torch.Tensor:
    """Process labels to be fed to any loss based on KR estimation with a multi-GPU/TPU
    strategy.

    When using a multi-GPU/TPU strategy, the flag `multi_gpu` in KR-based losses must
    be set to True and the labels have to be pre-processed with this function.

    For binary classification, the labels should be of shape [batch_size, 1].
    For multiclass problems, the labels must be one-hot encoded (1 or 0) with shape
    [batch_size, number of classes].

    Args:
        labels (torch.Tensor): tensor containing the labels.

    Returns:
        torch.Tensor: labels processed for KR-based losses with a multi-GPU/TPU
        strategy.
    """
    pos_labels = torch.where(labels > 0, 1.0, 0.0).to(labels.dtype)
    mean_pos = torch.mean(pos_labels, dim=0)
    # pos factor = batch_size / number of positive samples
    pos_factor = torch.nan_to_num(1.0 / mean_pos)
    # neg factor = batch_size / number of negative samples
    neg_factor = -torch.nan_to_num(1.0 / (1.0 - mean_pos))

    # Since the element-wise KR terms are averaged by the loss reduction later on,
    # the batch-size factor must be included here.
    return torch.where(labels > 0, pos_factor, neg_factor)
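
To make the weighting concrete, a hypothetical multiclass example (one-hot labels, values chosen for illustration):

import torch
from deel.torchlip.functional import process_labels_for_multi_gpu

labels = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 0.0],
                       [1.0, 0.0]])
weights = process_labels_for_multi_gpu(labels)
# Class 0: 3 positives -> +4/3, 1 negative -> -4
# Class 1: 1 positive  -> +4,   3 negatives -> -4/3
# weights[0] == tensor([ 1.3333, -1.3333])
# weights[1] == tensor([-4.0000,  4.0000])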
38 changes: 30 additions & 8 deletions deel/torchlip/modules/loss.py
@@ -38,19 +38,26 @@ class KRLoss(torch.nn.Module):
    duality.
    """

-    def __init__(self, reduction: str = "mean", true_values=None):
+    def __init__(self, multi_gpu=False, reduction: str = "mean", true_values=None):
        """
        Args:
-            true_values: tuple containing the two labels for each predicted class.
+            multi_gpu (bool): set to True when running on a multi-GPU/TPU strategy.
+            reduction: reduction applied to the loss ("mean" by default).
+            true_values: deprecated.
        """
        super().__init__()
        self.reduction = reduction
+        self.multi_gpu = multi_gpu
+        if multi_gpu:
+            self.kr_function = F.kr_loss_multi_gpu
+        else:
+            self.kr_function = F.kr_loss
        assert (
            true_values is None
        ), "deprecated true_values should be None (use target>0)"

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        loss_batch = F.kr_loss(input, target)
+        loss_batch = self.kr_function(input, target)
        return F.apply_reduction(loss_batch, self.reduction)
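
A usage sketch for the module form (not part of the diff; import path inferred from the file layout, data values hypothetical):

import torch
from deel.torchlip.functional import process_labels_for_multi_gpu
from deel.torchlip.modules.loss import KRLoss

criterion = KRLoss(multi_gpu=True)             # dispatches to F.kr_loss_multi_gpu
labels = torch.tensor([[1.0], [0.0]])
target = process_labels_for_multi_gpu(labels)
scores = torch.randn(2, 1)
loss = criterion(scores, target)               # "mean" reduction by default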


@@ -60,19 +67,26 @@ class NegKRLoss(torch.nn.Module):
    the Kantorovich-Rubinstein duality.
    """

-    def __init__(self, reduction: str = "mean", true_values=None):
+    def __init__(self, multi_gpu=False, reduction: str = "mean", true_values=None):
        """
        Args:
-            true_values: tuple containing the two labels for each predicted class.
+            multi_gpu (bool): set to True when running on a multi-GPU/TPU strategy.
+            reduction: reduction applied to the loss ("mean" by default).
+            true_values: deprecated.
        """
        super().__init__()
        self.reduction = reduction
+        self.multi_gpu = multi_gpu
+        if multi_gpu:
+            self.kr_function = F.neg_kr_loss_multi_gpu
+        else:
+            self.kr_function = F.neg_kr_loss
        assert (
            true_values is None
        ), "deprecated true_values should be None (use target>0)"

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        loss_batch = F.neg_kr_loss(input, target)
+        loss_batch = self.kr_function(input, target)
        return F.apply_reduction(loss_batch, self.reduction)


@@ -105,17 +119,25 @@ def __init__(
        self,
        alpha: float,
        min_margin: float = 1.0,
+        multi_gpu=False,
        reduction: str = "mean",
        true_values=None,
    ):
        """
        Args:
            alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
            min_margin: Minimal margin for the hinge loss.
-            true_values: tuple containing the two labels for each predicted class.
+            multi_gpu (bool): set to True when running on a multi-GPU/TPU strategy.
+            reduction: reduction applied to the loss ("mean" by default).
+            true_values: deprecated.
        """
        super().__init__()
        self.reduction = reduction
+        self.multi_gpu = multi_gpu
+        if multi_gpu:
+            self.hkr_function = F.hkr_loss_multi_gpu
+        else:
+            self.hkr_function = F.hkr_loss
        if (alpha >= 0) and (alpha <= 1):
            self.alpha = alpha
        else:
@@ -130,7 +152,7 @@ def __init__(
        ), "deprecated true_values should be None (use target>0)"

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        loss_batch = F.hkr_loss(input, target, self.alpha, self.min_margin)
+        loss_batch = self.hkr_function(input, target, self.alpha, self.min_margin)
        return F.apply_reduction(loss_batch, self.reduction)
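
The corresponding sketch for HKRLoss; a working assumption mirroring the docstrings above is that labels are weighted on the full batch before it is split across devices:

import torch
from deel.torchlip.functional import process_labels_for_multi_gpu
from deel.torchlip.modules.loss import HKRLoss

criterion = HKRLoss(alpha=0.95, min_margin=1.0, multi_gpu=True)
labels = torch.tensor([[1.0], [0.0], [1.0]])
target = process_labels_for_multi_gpu(labels)  # weight labels before sharding
scores = torch.randn(3, 1)
loss = criterion(scores, target)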


6 changes: 3 additions & 3 deletions tests/utils_framework.py
@@ -65,6 +65,7 @@
from deel.torchlip.modules import vanilla_model
from deel.torchlip.functional import invertible_downsample
from deel.torchlip.functional import invertible_upsample
+from deel.torchlip.functional import process_labels_for_multi_gpu

from deel.torchlip.utils.bjorck_norm import bjorck_norm, remove_bjorck_norm
from deel.torchlip.utils.frobenius_norm import (
@@ -147,7 +148,6 @@ def __call__(self, **kwargs):
TauCategoricalCrossentropyLoss = TauCrossEntropyLoss
TauSparseCategoricalCrossentropyLoss = TauCrossEntropyLoss
TauBinaryCrossentropyLoss = TauBCEWithLogitsLoss
-process_labels_for_multi_gpu = module_Unavailable_class
CategoricalProvableRobustAccuracy = module_Unavailable_class
BinaryProvableRobustAccuracy = module_Unavailable_class
CategoricalProvableAvgRobustness = module_Unavailable_class
@@ -224,13 +224,13 @@ def get_instance_withcheck(
    KRLoss: partial(
        get_instance_withcheck,
        dict_keys_replace={"name": None},
-        list_keys_notimplemented=["multi_gpu"],
+        list_keys_notimplemented=[],
    ),
    HingeMarginLoss: partial(get_instance_withcheck, dict_keys_replace={"name": None}),
    HKRLoss: partial(
        get_instance_withcheck,
        dict_keys_replace={"name": None},
-        list_keys_notimplemented=["multi_gpu"],
+        list_keys_notimplemented=[],
    ),
    HingeMulticlassLoss: partial(
        get_instance_withcheck,
