Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytest unit tests, new losses support, and normalization enhancement #22

Merged
merged 50 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
c75092b
add support for softHKR loss: SoftHKRMulticlassLoss warning alpha\in[…
Apr 15, 2024
432f706
add reduce mean for softHKR loss
Apr 15, 2024
4583add
detach to avoid graph expansion
Jun 5, 2024
f96fd0c
test uniformization using pytest
Oct 8, 2024
ae90761
update using uft
Oct 8, 2024
4325e43
update using uft
Oct 8, 2024
554e499
linter
Oct 8, 2024
37a54f5
flake8 errors
Oct 8, 2024
8a05fc7
flake8 errors test_metrics
Oct 8, 2024
016b9e1
seed on test normalizers
Oct 9, 2024
64aa108
linter on test_layers
Oct 9, 2024
7918a2b
linter on test_layers
Oct 9, 2024
f1688dc
avoid F401 linter error
Oct 9, 2024
576de52
test normalizer all close value
Oct 9, 2024
496fe04
F401 linter
Oct 9, 2024
929bed4
update kr_loss for nan support, and modify hkr losses to use 0<= alph…
Jul 1, 2024
a361056
add warning when alpha > 1. in SoftHKRMulticlassLoss
Oct 9, 2024
4007e77
update losses to support any target (target>0 for true value), alpha …
Oct 9, 2024
fbedaac
add vanilla export to InvertibleUpsampling class
Oct 9, 2024
f4c38c2
linter loss.py
Oct 9, 2024
20272e2
add MultiMarginLoss test based on pytorch implementation (warning no …
Oct 9, 2024
15f8ffe
add support for reduction in binary losses
Oct 9, 2024
2702d20
add support for reduction in multiclass losses
Oct 9, 2024
8920cb2
add support for reduction in softHKR loss
Oct 9, 2024
d68aee1
add supported lipschitz layer test in Sequential
Oct 11, 2024
c24ccf8
add support for Tau Cross Entropy losses and tests
Oct 11, 2024
115bd95
add support for multi gpu in binary losses
Oct 11, 2024
efc8ca9
linters
Oct 11, 2024
7c5cfc4
add support multi_gpu for all KR losses
Oct 11, 2024
0a9b9a4
linters
Oct 11, 2024
3c6bece
parenthesis tricks bjork
Jul 1, 2024
dbc5c1a
switch from n_iter to eps stopping criteria for Bjorck and spectral norm
Oct 14, 2024
a83eb67
switch from n_iter to eps stopping criteria + add spectral normalizati…
Oct 14, 2024
130e40a
linters
Oct 14, 2024
96333bd
update setup.cfg to move to python 39,310,311 and pt{1.10.2,1.13.1,2.…
Oct 14, 2024
4c8d32b
linters corrections
Oct 14, 2024
9a21dcc
linters corrections
Oct 14, 2024
13cdb98
Update python-lints.yml
franckma31 Oct 14, 2024
5e0d3ba
Update python-tests.yml
franckma31 Oct 14, 2024
a2729a5
update workflow github
Oct 14, 2024
9696ab4
update workflow github
Oct 14, 2024
2cfa164
update requirements and pytorch version
Oct 14, 2024
d7481a2
update version with file + cleaning
Oct 14, 2024
9658d27
update docstring in losses
franckma31 Nov 7, 2024
26fd0e3
clean test_condense
franckma31 Nov 7, 2024
e83741e
remove warning on eps_spectral
franckma31 Nov 7, 2024
f5acef3
remove reference to Keras in test files
franckma31 Nov 7, 2024
9c2d4a1
replace Reshape layer by torch.nn.UnShuffle for tests
franckma31 Nov 13, 2024
014764d
updated atol testing value due to numerical imprecision on random ini…
franckma31 Nov 13, 2024
6d7d3a1
update docs notebooks with new losses
franckma31 Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions deel/torchlip/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import abc
import copy
import logging
import warnings
import math
from collections import OrderedDict
from typing import Any
Expand All @@ -40,8 +41,27 @@
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.nn import Sequential as TorchSequential
from torch import reshape

logger = logging.getLogger("deel.torchlip")

def _is_supported_1lip_layer(layer):
    """Return True if the given torch layer is known to be 1-Lipschitz.

    Note that in some cases, the layer is 1-Lipschitz only for a specific
    set of parameters (e.g. ``MaxPool2d`` with non-overlapping windows).

    Args:
        layer: the ``torch.nn.Module`` instance to check.

    Returns:
        ``True`` if the layer is recognized as 1-Lipschitz, ``False``
        otherwise (unknown layers are conservatively reported as not
        supported).
    """
    supported_1lip_layers = (
        torch.nn.Softmax,
        torch.nn.Flatten,
        torch.nn.Identity,
        torch.nn.ReLU,
        torch.nn.Sigmoid,
        torch.nn.Tanh,
        Reshape,
    )
    if isinstance(layer, supported_1lip_layers):
        return True
    elif isinstance(layer, torch.nn.MaxPool2d):
        # Max pooling is 1-Lipschitz only when windows do not overlap,
        # i.e. the kernel size does not exceed the stride.
        return layer.kernel_size <= layer.stride
    return False


def vanilla_model(model: nn.Module):
Expand Down Expand Up @@ -133,13 +153,13 @@ def __init__(

# Force the Lipschitz coefficient:
n_layers = np.sum(
(isinstance(layer, LipschitzModule) for layer in self.children())
[isinstance(layer, LipschitzModule) for layer in self.children()]
)
for module in self.children():
if isinstance(module, LipschitzModule):
module._coefficient_lip = math.pow(k_coef_lip, 1 / n_layers)
else:
logger.warning(
elif _is_supported_1lip_layer(module) is not True:
warnings.warn(
"Sequential model contains a layer which is not a Lipschitz layer: {}".format( # noqa: E501
module
)
Expand All @@ -163,3 +183,12 @@ def vanilla_export(self):
else:
layers.append((name, copy.deepcopy(layer)))
return TorchSequential(OrderedDict(layers))


class Reshape(torch.nn.Module):
    """Layer that reshapes its input tensor to a fixed target shape.

    Thin ``nn.Module`` wrapper around tensor reshaping, so that a reshape
    operation can be placed inside a ``Sequential`` container.
    """

    def __init__(self, target_shape):
        super().__init__()
        # Shape applied to every incoming tensor in forward().
        self.target_shape = target_shape

    def forward(self, x):
        return x.reshape(self.target_shape)
franckma31 marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 3 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_warning_unsupported_1Lip_layers():
# Check that unsupported layers raise a warning
unsupported_layers = [
uft.get_instance_framework(
tMaxPool2d, {"pool_size": 3, "strides": 2}
tMaxPool2d, {"kernel_size": 3, "stride": 2}
), # kl.MaxPool2d(),
uft.get_instance_framework(tAdd, {}), # kl.Add(),
uft.get_instance_framework(tConcatenate, {}), # kl.Concatenate(),
Expand All @@ -309,8 +309,8 @@ def test_warning_unsupported_1Lip_layers():
# unsupported_layers.append(kl.Activation("gelu"))

for lay in unsupported_layers:
with pytest.warns(Warning):
if lay is not None:
if lay is not None:
with pytest.warns(Warning):
_ = uft.generate_k_lip_model(
Sequential,
{"layers": [lay]},
Expand Down