Skip to content

Commit ae8366c

Browse files
committed
clean and linter
1 parent d19ad7d commit ae8366c

File tree

7 files changed

+17
-40
lines changed

7 files changed

+17
-40
lines changed

deel/torchlip/modules/downsampling.py

-15
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@
2424
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
2525
# CRIAQ and ANITI - https://www.deel.ai/
2626
# =====================================================================================
27-
from typing import Tuple
28-
2927
import torch
3028

31-
from .. import functional as F
3229
from .module import LipschitzModule
3330

3431

@@ -43,15 +40,3 @@ def vanilla_export(self):
4340
else:
4441
return self
4542

46-
47-
# class InvertibleDownSampling(torch.nn.Module, LipschitzModule):
48-
# def __init__(self, kernel_size: Tuple[int, int], k_coef_lip: float = 1.0):
49-
# torch.nn.Module.__init__(self)
50-
# LipschitzModule.__init__(self, k_coef_lip)
51-
# self.kernel_size = kernel_size
52-
53-
# def forward(self, input: torch.Tensor) -> torch.Tensor:
54-
# return F.invertible_downsample(input, self.kernel_size) * self._coefficient_lip
55-
56-
# def vanilla_export(self):
57-
# return self

deel/torchlip/modules/residual.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,20 @@
2424
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
2525
# CRIAQ and ANITI - https://www.deel.ai/
2626
# =====================================================================================
27-
from typing import Tuple
28-
from typing import Union
2927

3028
import torch
3129
from torch import nn
3230

3331

34-
class LipResidual(nn.Module):
32+
class LipResidual(nn.Module):
3533
"""
3634
This class is a 1-Lipschitz residual connection
3735
With a learnable parameter alpha that give a tradeoff
3836
between the x and the layer y=l(x)
39-
37+
4038
Args:
4139
"""
40+
4241
def __init__(self):
4342
super().__init__()
4443
self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

deel/torchlip/modules/unconstrained.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,10 @@ def __init__(
156156
self.old_padding_mode = padding_mode
157157
if padding_mode.lower() == "symmetric":
158158
# symmetric padding of one pixel can be replaced by replicate
159-
if (isinstance(padding, int) and padding <= 1) or \
160-
(isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1):
161-
self.old_padding_mode = padding_mode = "replicate"
159+
if (isinstance(padding, int) and padding <= 1) or (
160+
isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1
161+
):
162+
self.old_padding_mode = padding_mode = "replicate"
162163
else:
163164
padding_mode = "zeros"
164165
padding = "valid"

deel/torchlip/modules/upsampling.py

-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
2525
# CRIAQ and ANITI - https://www.deel.ai/
2626
# =====================================================================================
27-
from typing import Tuple
28-
from typing import Union
2927

3028
import torch
3129

tests/test_unconstrained_layers.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,18 @@ def compare(x, x_ref, index_x=[], index_x_ref=[]):
4646
np.testing.assert_allclose(x_cropped, np.zeros(x_cropped.shape), 1e-2, 0)
4747
else:
4848
np.testing.assert_allclose(
49-
x_cropped-
50-
x_ref[
49+
x_cropped
50+
- x_ref[
5151
:, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
5252
][:, :, :: index_x_ref[2], :: index_x_ref[5]],
5353
np.zeros(x_cropped.shape),
5454
1e-2,
5555
0,
5656
)
57-
# np.testing.assert_allclose(
58-
# x_cropped,
59-
# x_ref[
60-
# :, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
61-
# ][:, :, :: index_x_ref[2], :: index_x_ref[5]],
62-
# 1e-2,
63-
# 0,
64-
# )
6557

6658

6759
@pytest.mark.parametrize(
68-
"padding_tested", ["circular", "constant", "symmetric", "reflect","replicate"]
60+
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate"]
6961
)
7062
@pytest.mark.parametrize(
7163
"input_shape, batch_size, kernel_size, filters",
@@ -167,7 +159,8 @@ def test_padding(padding_tested, input_shape, batch_size, kernel_size, filters):
167159
reason="PadConv2d not available",
168160
)
169161
@pytest.mark.parametrize(
170-
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
162+
"padding_tested",
163+
["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"],
171164
)
172165
@pytest.mark.parametrize(
173166
"input_shape, batch_size, kernel_size, filters",
@@ -240,7 +233,8 @@ def test_predict(padding_tested, input_shape, batch_size, kernel_size, filters):
240233
reason="PadConv2d not available",
241234
)
242235
@pytest.mark.parametrize(
243-
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
236+
"padding_tested",
237+
["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"],
244238
)
245239
@pytest.mark.parametrize(
246240
"input_shape, batch_size, kernel_size, filters",

tests/test_updownsampling.py

-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434

3535
def check_downsample(x, y, kernel_size):
36-
shape = uft.get_NCHW(x)
3736
index = 0
3837
for dx in range(kernel_size):
3938
for dy in range(kernel_size):

tests/utils_framework.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
"CategoricalHingeLoss",
128128
"process_labels_for_multi_gpu",
129129
"SpectralConv1d",
130+
"LipResidual",
130131
]
131132

132133

@@ -625,7 +626,7 @@ def is_supported_padding(padding):
625626
"reflect",
626627
"circular",
627628
"symmetric",
628-
'replicate'
629+
"replicate",
629630
] # "constant",
630631

631632

@@ -635,7 +636,7 @@ def pad_input(x, padding, kernel_size):
635636
kernel_size = [kernel_size, kernel_size]
636637
if padding.lower() in ["same", "valid"]:
637638
return x
638-
elif padding.lower() in ["constant", "reflect", "circular",'replicate']:
639+
elif padding.lower() in ["constant", "reflect", "circular", "replicate"]:
639640
p_vert, p_hor = kernel_size[0] // 2, kernel_size[1] // 2
640641
pad_sizes = [
641642
p_hor,

0 commit comments

Comments
 (0)