Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytest unit tests, new losses support, and normalization enhancement #22

Merged
merged 50 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
c75092b
add support for softHKR loss: SoftHKRMulticlassLoss warning alpha\in[…
Apr 15, 2024
432f706
add reduce mean for softHKR loss
Apr 15, 2024
4583add
detach to avoid graph expansion
Jun 5, 2024
f96fd0c
test uniformization using pytest
Oct 8, 2024
ae90761
update using uft
Oct 8, 2024
4325e43
update using uft
Oct 8, 2024
554e499
linter
Oct 8, 2024
37a54f5
flake8 errors
Oct 8, 2024
8a05fc7
flake8 errors test_metrics
Oct 8, 2024
016b9e1
seed on test normalizers
Oct 9, 2024
64aa108
linter on test_layers
Oct 9, 2024
7918a2b
linter on test_layers
Oct 9, 2024
f1688dc
avoid F401 linter error
Oct 9, 2024
576de52
test normalizer all close value
Oct 9, 2024
496fe04
F401 linter
Oct 9, 2024
929bed4
update kr_loss for nan support, and modify hkr losses to use 0<= alph…
Jul 1, 2024
a361056
add warning when alpha > 1. in SoftHKRMulticlassLoss
Oct 9, 2024
4007e77
update losses to support any target (target>0 for true value), alpha …
Oct 9, 2024
fbedaac
add vanilla export to InvertibleUpsampling class
Oct 9, 2024
f4c38c2
linter loss.py
Oct 9, 2024
20272e2
add MultiMarginLoss test based on pytorch implementation (warning no …
Oct 9, 2024
15f8ffe
add support for reduction in binary losses
Oct 9, 2024
2702d20
add support for reduction in multiclass losses
Oct 9, 2024
8920cb2
add support for reduction in softHKR loss
Oct 9, 2024
d68aee1
add supported lipschitz layer test in Sequential
Oct 11, 2024
c24ccf8
add support for Tau Cross Entropy losses and tests
Oct 11, 2024
115bd95
add support for multi gpu in binary losses
Oct 11, 2024
efc8ca9
linters
Oct 11, 2024
7c5cfc4
add support multi_gpu for all KR losses
Oct 11, 2024
0a9b9a4
linters
Oct 11, 2024
3c6bece
parenthesis tricks Bjorck
Jul 1, 2024
dbc5c1a
switch from n_iter to eps stopping criteria for Bjorck and spectral norm
Oct 14, 2024
a83eb67
switch from n_iter to eps stopping criteria + add spectral normalizati…
Oct 14, 2024
130e40a
linters
Oct 14, 2024
96333bd
update setup.cfg to move to python 39,310,311 and pt{1.10.2,1.13.1,2.…
Oct 14, 2024
4c8d32b
linters corrections
Oct 14, 2024
9a21dcc
linters corrections
Oct 14, 2024
13cdb98
Update python-lints.yml
franckma31 Oct 14, 2024
5e0d3ba
Update python-tests.yml
franckma31 Oct 14, 2024
a2729a5
update workflow github
Oct 14, 2024
9696ab4
update workflow github
Oct 14, 2024
2cfa164
update requirements and pytorch version
Oct 14, 2024
d7481a2
update version with file + cleaning
Oct 14, 2024
9658d27
update docstring in losses
franckma31 Nov 7, 2024
26fd0e3
clean test_condense
franckma31 Nov 7, 2024
e83741e
remove warning on eps_spectral
franckma31 Nov 7, 2024
f5acef3
remove reference to Keras in test files
franckma31 Nov 7, 2024
9c2d4a1
replace Reshape layer by torch.nn.UnShuffle for tests
franckma31 Nov 13, 2024
014764d
updated atol testing value due to numerical imprecision on random ini…
franckma31 Nov 13, 2024
6d7d3a1
update docs notebooks with new losses
franckma31 Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions deel/torchlip/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ def lipschitz_prelu(


# Losses
def apply_reduction(val: torch.Tensor, reduction: str) -> torch.Tensor:
    """Apply a standard loss reduction to a tensor of element-wise loss values.

    Args:
        val: Tensor of per-element (e.g. per-sample) loss values.
        reduction: One of ``"mean"``, ``"sum"`` or ``"none"`` (PyTorch loss
            convention). ``"none"`` returns ``val`` unchanged.

    Returns:
        The reduced tensor (a scalar for ``"mean"``/``"sum"``), or ``val``
        itself when ``reduction == "none"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported strings.
            (An explicit whitelist is used instead of ``getattr(torch, ...)``
            so that typos fail loudly rather than silently skipping the
            reduction, and arbitrary torch functions cannot be invoked.)
    """
    if reduction == "mean":
        return torch.mean(val)
    if reduction == "sum":
        return torch.sum(val)
    if reduction == "none":
        return val
    raise ValueError(
        f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'."
    )


def kr_loss(
Expand Down Expand Up @@ -305,12 +310,19 @@ def kr_loss(
"""

target = target.view(input.shape)
pos_target = (target > 0).to(input.dtype)
mean_pos = torch.mean(pos_target, dim=0)
# pos factor = batch_size/number of positive samples
pos_factor = torch.nan_to_num(1.0 / mean_pos)
# neg factor = batch_size/number of negative samples
neg_factor = -torch.nan_to_num(1.0 / (1.0 - mean_pos))

c1 = torch.mean(input[target > 0])
c1 = torch.nan_to_num(c1)
c2 = torch.mean(input[target <= 0])
c2 = torch.nan_to_num(c2)
return c1 - c2
weighted_input = torch.where(target > 0, pos_factor, neg_factor) * input
# Since element-wise KR terms are averaged by loss reduction later on, it is needed
# to multiply by batch_size here.
# In binary case (`y_true` of shape (batch_size, 1)), `tf.reduce_mean(axis=-1)`
# behaves like `tf.squeeze()` to return element-wise loss of shape (batch_size, ).
return torch.mean(weighted_input, dim=-1)


def neg_kr_loss(
Expand Down Expand Up @@ -359,8 +371,8 @@ def hinge_margin_loss(
"""
target = target.view(input.shape)
sign_target = torch.where(target > 0, 1.0, -1.0).to(input.dtype)

return torch.mean(F.relu(min_margin / 2.0 - sign_target * input))
hinge = F.relu(min_margin / 2.0 - sign_target * input)
return torch.mean(hinge, dim=-1)


def hkr_loss(
Expand Down Expand Up @@ -395,8 +407,9 @@ def hkr_loss(
return -kr_loss(input, target)
# true value: positive value should be the first to be coherent with the
# hinge loss (positive y_pred)
return alpha * hinge_margin_loss(input, target, min_margin)
-(1 - alpha) * kr_loss(input, target)
return alpha * hinge_margin_loss(input, target, min_margin) - (1 - alpha) * kr_loss(
input, target
)


def kr_multiclass_loss(
Expand Down
23 changes: 16 additions & 7 deletions deel/torchlip/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,20 @@ class KRLoss(torch.nn.Module):
duality.
"""

def __init__(self, true_values=None):
def __init__(self, reduction: str = "mean", true_values=None):
"""
Args:
true_values: tuple containing the two label for each predicted class.
franckma31 marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__()
self.reduction = reduction
assert (
true_values is None
), "depreciated true_values should be None (use target>0)"

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.kr_loss(input, target)
loss_batch = F.kr_loss(input, target)
return F.apply_reduction(loss_batch, self.reduction)


class NegKRLoss(torch.nn.Module):
Expand All @@ -55,35 +57,39 @@ class NegKRLoss(torch.nn.Module):
the Kantorovich-Rubinstein duality.
"""

def __init__(self, true_values=None):
def __init__(self, reduction: str = "mean", true_values=None):
"""
Args:
true_values: tuple containing the two label for each predicted class.
"""
super().__init__()
self.reduction = reduction
assert (
true_values is None
), "depreciated true_values should be None (use target>0)"

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.neg_kr_loss(input, target)
loss_batch = F.neg_kr_loss(input, target)
return F.apply_reduction(loss_batch, self.reduction)


class HingeMarginLoss(torch.nn.Module):
"""
Hinge margin loss.
"""

def __init__(self, min_margin: float = 1.0):
def __init__(self, min_margin: float = 1.0, reduction: str = "mean"):
"""
Args:
min_margin: The minimal margin to enforce.
"""
super().__init__()
self.reduction = reduction
self.min_margin = min_margin

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.hinge_margin_loss(input, target, self.min_margin)
loss_batch = F.hinge_margin_loss(input, target, self.min_margin)
return F.apply_reduction(loss_batch, self.reduction)


class HKRLoss(torch.nn.Module):
Expand All @@ -96,6 +102,7 @@ def __init__(
self,
alpha: float,
min_margin: float = 1.0,
reduction: str = "mean",
true_values=None,
):
"""
Expand All @@ -105,6 +112,7 @@ def __init__(
true_values: tuple containing the two label for each predicted class.
"""
super().__init__()
self.reduction = reduction
if (alpha >= 0) and (alpha <= 1):
self.alpha = alpha
else:
Expand All @@ -119,7 +127,8 @@ def __init__(
), "depreciated true_values should be None (use target>0)"

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.hkr_loss(input, target, self.alpha, self.min_margin)
loss_batch = F.hkr_loss(input, target, self.alpha, self.min_margin)
return F.apply_reduction(loss_batch, self.reduction)


class KRMulticlassLoss(torch.nn.Module):
Expand Down
11 changes: 6 additions & 5 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ def test_softhkrmulticlass_loss():
},
binary_data(y_true1b),
binary_data(y_pred1b),
[0.25, -2.2, 2.95, -0.65, -2.6, 0.7, 3.4, -1.55],
np.float64([0.25, -2.2, 2.95, -0.65, -2.6, 0.7, 3.4, -1.55])
* uft.scaleAlpha(2.5),
1e-7,
),
(
Expand Down Expand Up @@ -545,7 +546,7 @@ def test_no_reduction_loss_generic(
expected_loss,
rtol=rtol,
atol=5e-6,
err_msg=f"Loss {loss.name} failed",
err_msg=f"Loss {loss} failed",
)


Expand Down Expand Up @@ -749,7 +750,7 @@ def test_minibatches_binary_loss_generic(
expected_loss,
rtol=rtol,
atol=5e-6,
err_msg=f"Loss {loss.name} failed",
err_msg=f"Loss {loss} failed",
)
loss_val_minibatches = 0
for i in range(len(segments) - 1):
Expand All @@ -764,7 +765,7 @@ def test_minibatches_binary_loss_generic(
loss_val_minibatches,
rtol=rtol, # 5e-6,
atol=5e-6,
err_msg=f"Loss {loss.name} failed for hardcoded mini-batches",
err_msg=f"Loss {loss} failed for hardcoded mini-batches",
)


Expand Down Expand Up @@ -870,5 +871,5 @@ def test_multilabel_loss_generic(loss_instance, loss_params, rtol):
mean_loss_vals,
rtol=rtol, # 5e-6,
atol=1e-4,
err_msg=f"Loss {loss.name} failed",
err_msg=f"Loss {loss} failed",
)
14 changes: 10 additions & 4 deletions tests/utils_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,17 @@ def get_instance_withcheck(
ScaledAvgPool2d: partial(
get_instance_withreplacement, dict_keys_replace={"data_format": None}
),
KRLoss: partial(get_instance_withcheck, list_keys_notimplemented=["reduction"]),
HingeMarginLoss: partial(
get_instance_withcheck, list_keys_notimplemented=["reduction"]
KRLoss: partial(
get_instance_withcheck,
dict_keys_replace={"name": None},
list_keys_notimplemented=["multi_gpu"],
),
HingeMarginLoss: partial(get_instance_withcheck, dict_keys_replace={"name": None}),
HKRLoss: partial(
get_instance_withcheck,
dict_keys_replace={"name": None},
list_keys_notimplemented=["multi_gpu"],
),
HKRLoss: partial(get_instance_withcheck, list_keys_notimplemented=["reduction"]),
HingeMulticlassLoss: partial(
get_instance_withcheck, list_keys_notimplemented=["reduction"]
),
Expand Down