Skip to content

Commit

Permalink
update docstring in losses
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Nov 7, 2024
1 parent d7481a2 commit 9658d27
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 69 deletions.
38 changes: 20 additions & 18 deletions deel/torchlip/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,7 @@ def apply_reduction(val: torch.Tensor, reduction: str) -> torch.Tensor:
return red(val)


def kr_loss(
input: torch.Tensor, target: torch.Tensor, multi_gpu=False, true_values=None
) -> torch.Tensor:
def kr_loss(input: torch.Tensor, target: torch.Tensor, multi_gpu=False) -> torch.Tensor:
r"""
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
as per
Expand All @@ -300,12 +298,19 @@ def kr_loss(
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
two possible labels as specific by ``true_values``.
two possible labels as specified by their sign.
`target` accepts label values in (0, 1), (-1, 1), or labels pre-processed with the
`deel.torchlip.functional.process_labels_for_multi_gpu()` function.
Using a multi-GPU/TPU strategy requires setting `multi_gpu` to True and
pre-processing the labels `target` with the
`deel.torchlip.functional.process_labels_for_multi_gpu()` function.
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
true_values: depreciated (target>0 is used)
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The Wasserstein-1 loss between ``input`` and ``target``.
Expand All @@ -316,9 +321,7 @@ def kr_loss(
return kr_loss_standard(input, target)


def kr_loss_standard(
input: torch.Tensor, target: torch.Tensor, true_values=None
) -> torch.Tensor:
def kr_loss_standard(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality,
as per
Expand All @@ -329,12 +332,13 @@ def kr_loss_standard(
- \underset{\mathbf{x}\sim{}\nu}{\mathbb{E}}[f(\mathbf{x})]
where :math:`\mu` and :math:`\nu` are the distributions corresponding to the
two possible labels as specific by ``true_values``.
two possible labels as specified by their sign.
`target` accepts label values in (0, 1), (-1, 1)
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
true_values: depreciated (target>0 is used)
Returns:
The Wasserstein-1 loss between ``input`` and ``target``.
Expand Down Expand Up @@ -384,7 +388,6 @@ def neg_kr_loss(
input: torch.Tensor,
target: torch.Tensor,
multi_gpu=False,
true_values=None,
) -> torch.Tensor:
"""
Loss to estimate the negative wasserstein-1 distance using Kantorovich-Rubinstein
Expand All @@ -393,7 +396,7 @@ def neg_kr_loss(
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
true_values: depreciated (target>0 is used)
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The negative Wasserstein-1 loss between ``input`` and ``target``.
Expand Down Expand Up @@ -437,7 +440,6 @@ def hkr_loss(
alpha: float,
min_margin: float = 1.0,
multi_gpu=False,
true_values=None,
) -> torch.Tensor:
"""
Loss to estimate the wasserstein-1 distance with a hinge regularization using
Expand All @@ -446,9 +448,9 @@ def hkr_loss(
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
alpha: Regularization factor between the hinge and the KR loss.
alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
min_margin: Minimal margin for the hinge loss.
true_values: tuple containing the two label for each predicted class.
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The regularized Wasserstein-1 loss.
Expand Down Expand Up @@ -478,7 +480,7 @@ def hinge_multiclass_loss(
"""
Loss to estimate the Hinge loss in a multiclass setup. It computes the
elementwise hinge term. Note that this formulation differs from the
one commonly found in tensorflow/pytorch (with marximise the difference
one commonly found in tensorflow/pytorch (which maximises the difference
between the two largest logits). This formulation is consistent with the
binary classification loss used in a multiclass fashion.
Expand Down Expand Up @@ -515,9 +517,9 @@ def hkr_multiclass_loss(
Args:
input: Tensor of arbitrary shape.
target: Tensor of the same shape as input.
alpha: Regularization factor between the hinge and the KR loss.
alpha: Regularization factor ([0,1]) between the hinge and the KR loss.
min_margin: Minimal margin for the hinge loss.
true_values: tuple containing the two label for each predicted class.
multi_gpu (bool): set to True when running on multi-GPU/TPU
Returns:
The regularized Wasserstein-1 loss.
Expand Down
Loading

0 comments on commit 9658d27

Please sign in to comment.