From 9c2d4a126ef33f992e246d4eeacbb0cd803c1a5f Mon Sep 17 00:00:00 2001 From: Franck Mamalet <49721198+franckma31@users.noreply.github.com> Date: Wed, 13 Nov 2024 01:17:41 +0100 Subject: [PATCH] replace Reshape layer by torch.nn.UnShuffle for tests --- deel/torchlip/modules/__init__.py | 1 - deel/torchlip/modules/module.py | 12 +----------- tests/test_compute_layer_sv.py | 2 +- tests/test_layers.py | 4 +--- tests/test_models.py | 4 +++- tests/utils_framework.py | 2 +- 6 files changed, 7 insertions(+), 18 deletions(-) diff --git a/deel/torchlip/modules/__init__.py b/deel/torchlip/modules/__init__.py index dca924b..af8406c 100644 --- a/deel/torchlip/modules/__init__.py +++ b/deel/torchlip/modules/__init__.py @@ -69,7 +69,6 @@ from .module import LipschitzModule from .module import Sequential from .module import vanilla_model -from .module import Reshape from .pooling import ScaledAdaptiveAvgPool2d from .pooling import ScaledAvgPool2d from .pooling import ScaledL2NormPool2d diff --git a/deel/torchlip/modules/module.py b/deel/torchlip/modules/module.py index ccc8f62..a951a61 100644 --- a/deel/torchlip/modules/module.py +++ b/deel/torchlip/modules/module.py @@ -40,7 +40,6 @@ import torch.nn as nn import torch.nn.utils.parametrize as parametrize from torch.nn import Sequential as TorchSequential -from torch import reshape def _is_supported_1lip_layer(layer): @@ -54,7 +53,7 @@ def _is_supported_1lip_layer(layer): torch.nn.ReLU, torch.nn.Sigmoid, torch.nn.Tanh, - Reshape, + torch.nn.Unflatten, ) if isinstance(layer, supported_1lip_layers): return True @@ -182,12 +181,3 @@ def vanilla_export(self): else: layers.append((name, copy.deepcopy(layer))) return TorchSequential(OrderedDict(layers)) - - -class Reshape(torch.nn.Module): - def __init__(self, target_shape): - super(Reshape, self).__init__() - self.target_shape = target_shape - - def forward(self, x): - return reshape(x, self.target_shape) diff --git a/tests/test_compute_layer_sv.py b/tests/test_compute_layer_sv.py index 9d37fca..9feca85 100644 --- a/tests/test_compute_layer_sv.py +++ b/tests/test_compute_layer_sv.py @@ -166,7 +166,7 @@ def train_compute_and_verifySV( logdir = os.path.join("logs", uft.LIP_LAYERS, "%s" % layer_type.__name__) os.makedirs(logdir, exist_ok=True) - callback_list = [] + callback_list = [] if "callbacks" in kwargs and (kwargs["callbacks"] is not None): callback_list = callback_list + kwargs["callbacks"] # train model diff --git a/tests/test_layers.py b/tests/test_layers.py index 107d0d8..5ca4008 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -193,9 +193,7 @@ def train_k_lip_model( logdir = os.path.join("logs", uft.LIP_LAYERS, "%s" % layer_type.__name__) os.makedirs(logdir, exist_ok=True) - callback_list = ( - [] - ) + callback_list = [] if kwargs["callbacks"] is not None: callback_list = callback_list + kwargs["callbacks"] # train model diff --git a/tests/test_models.py b/tests/test_models.py index d68eacc..33cf2ef 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -263,7 +263,9 @@ def test_warning_unsupported_1Lip_layers(): ), # kl.Activation("relu"), uft.get_instance_framework(tSoftmax, {}), # kl.Softmax(), uft.get_instance_framework(Flatten, {}), # kl.Flatten(), - uft.get_instance_framework(tReshape, {"target_shape": (10,)}), # kl.Reshape(), + uft.get_instance_framework( + tReshape, {"dim": -1, "unflattened_size": (10,)} + ), # kl.Reshape(), uft.get_instance_framework( tMaxPool2d, {"kernel_size": (2, 2)} ), # kl.MaxPool2d(), diff --git a/tests/utils_framework.py b/tests/utils_framework.py index a97cf63..980fee8 100644 --- a/tests/utils_framework.py +++ b/tests/utils_framework.py @@ -23,6 +23,7 @@ from torch.nn import Conv2d as tConv2d from torch.nn import Conv2d as PadConv2d from torch.nn import Upsample as tUpSampling2d +from torch.nn import Unflatten as tReshape from torch import int32 as type_int32 from torch.nn.functional import pad from torch.nn import MultiMarginLoss as tMultiMarginLoss @@ -40,7 +41,6 @@ from deel.torchlip.modules import ScaledL2NormPool2d from deel.torchlip.modules import InvertibleDownSampling from deel.torchlip.modules import InvertibleUpSampling -from deel.torchlip.modules import Reshape as tReshape from deel.torchlip.utils import evaluate_lip_const from deel.torchlip.modules import (