Skip to content

Commit

Permalink
add padconv support + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Oct 28, 2024
1 parent 08f82de commit 852baa5
Show file tree
Hide file tree
Showing 7 changed files with 583 additions and 243 deletions.
39 changes: 39 additions & 0 deletions deel/torchlip/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,42 @@ def process_labels_for_multi_gpu(labels: torch.Tensor) -> torch.Tensor:
# Since element-wise KR terms are averaged by loss reduction later on, it is needed
# to multiply by batch_size here.
return torch.where(labels > 0, pos_factor, neg_factor)


class SymmetricPad(torch.nn.Module):
"""
Pads a 2D tensor symmetrically.
Args:
pad (tuple): A tuple (pad_left, pad_right, pad_top, pad_bottom) specifying
the number of pixels to pad on each side. (or single int if
common padding).
onedim: False for conv2d, True for conv1d.
"""

def __init__(self, pad, onedim=False):
super().__init__()
self.onedim = onedim
num_dim = 2 if onedim else 4
if isinstance(pad, int):
self.pad = (pad,) * num_dim
else:
self.pad = torch.nn.modules.utils._reverse_repeat_tuple(pad, 2)
assert len(self.pad) == num_dim, f"Pad must be a tuple of {num_dim} integers"

def forward(self, x):

# Horizontal padding
left = x[:, ..., : self.pad[0]].flip(dims=[-1])
right = x[:, ..., -self.pad[1] :].flip(dims=[-1])
x = torch.cat([left, x, right], dim=-1)
if self.onedim:
return x
# Vertical padding
top = x[:, :, : self.pad[2], :].flip(dims=[-2])
bottom = x[:, :, -self.pad[3] :, :].flip(dims=[-2])
x = torch.cat([top, x, bottom], dim=-2)

return x
2 changes: 2 additions & 0 deletions deel/torchlip/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,5 @@
from .upsampling import InvertibleUpSampling
from .normalization import LayerCentering
from .normalization import BatchCentering
from .unconstrained import PadConv2d
from .unconstrained import PadConv1d
68 changes: 17 additions & 51 deletions deel/torchlip/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@
from ..normalizers import DEFAULT_EPS_SPECTRAL
from ..utils import frobenius_norm
from ..utils import lconv_norm
from .unconstrained import PadConv1d, PadConv2d
from .module import LipschitzModule


class SpectralConv1d(torch.nn.Conv1d, LipschitzModule):
class SpectralConv1d(PadConv1d, LipschitzModule):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -89,14 +90,16 @@ def __init__(
# if padding_mode != "same":
# raise RuntimeError("NormalizedConv only support padding='same'")

torch.nn.Conv1d.__init__(
PadConv1d.__init__(
self,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
dilation=dilation,
groups=groups,
padding_mode=padding_mode,
)
LipschitzModule.__init__(self, k_coef_lip)
Expand All @@ -115,24 +118,10 @@ def __init__(
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Conv1d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=self.bias is not None,
padding_mode=self.padding_mode,
)
layer.weight.data = self.weight.detach()
if self.bias is not None:
layer.bias.data = self.bias.detach()
return layer
return PadConv1d.vanilla_export(self)


class SpectralConv2d(torch.nn.Conv2d, LipschitzModule):
class SpectralConv2d(PadConv2d, LipschitzModule):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -185,14 +174,16 @@ def __init__(
# if padding_mode != "same":
# raise RuntimeError("NormalizedConv only support padding='same'")

torch.nn.Conv2d.__init__(
PadConv2d.__init__(
self,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
dilation=dilation,
groups=groups,
padding_mode=padding_mode,
)
LipschitzModule.__init__(self, k_coef_lip)
Expand All @@ -211,24 +202,10 @@ def __init__(
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=self.bias is not None,
padding_mode=self.padding_mode,
)
layer.weight.data = self.weight.detach()
if self.bias is not None:
layer.bias.data = self.bias.detach()
return layer
return PadConv2d.vanilla_export(self)


class FrobeniusConv2d(torch.nn.Conv2d, LipschitzModule):
class FrobeniusConv2d(PadConv2d, LipschitzModule):
"""
Same as SpectralConv2d but in the case of a single output.
"""
Expand All @@ -251,14 +228,17 @@ def __init__(
# if padding_mode != "same":
# raise RuntimeError("NormalizedConv only support padding='same'")

torch.nn.Conv2d.__init__(
PadConv2d.__init__(
self,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
padding_mode=padding_mode,
bias=bias,
dilation=dilation,
groups=groups,
)
LipschitzModule.__init__(self, k_coef_lip)

Expand All @@ -271,21 +251,7 @@ def __init__(
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=self.bias is not None,
padding_mode=self.padding_mode,
)
layer.weight.data = self.weight.detach()
if self.bias is not None:
layer.bias.data = self.bias.detach()
return layer
return PadConv2d.vanilla_export(self)


class SpectralConvTranspose2d(torch.nn.ConvTranspose2d, LipschitzModule):
Expand Down
201 changes: 201 additions & 0 deletions deel/torchlip/modules/unconstrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================

from typing import Union
import torch
from torch.nn.common_types import _size_1_t, _size_2_t
from ..functional import SymmetricPad


class PadConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
):
"""
This class is a Conv1d Layer with additional padding modes
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution.
padding (int or tuple, optional): Zero-padding added to both sides of
the input.
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
``'replicate'``,``'symmetric'`` or ``'circular'``.
Default: ``'zeros'``
dilation (int or tuple, optional): Spacing between kernel elements.
Has to be one
groups (int, optional): Number of blocked connections from input
channels to output channels. Has to be one
bias (bool, optional): If ``True``, adds a learnable bias to the
output.
This documentation reuse the body of the original torch.nn.Conv1d doc.
"""

self.old_padding = padding
self.old_padding_mode = padding_mode
if padding_mode.lower() == "symmetric":
padding_mode = "zeros"
padding = "valid"

super(PadConv1d, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
dilation=dilation,
groups=groups,
padding_mode=padding_mode,
)

if self.old_padding_mode.lower() == "symmetric":
self.pad = SymmetricPad(self.old_padding, onedim=True)
else:
self.pad = lambda x: x

def forward(self, input: torch.Tensor) -> torch.Tensor:
return super(PadConv1d, self).forward(self.pad(input))

def vanilla_export(self):
if self.old_padding_mode.lower() == "symmetric":
next_layer_type = PadConv1d
else:
next_layer_type = torch.nn.Conv1d

layer = next_layer_type(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.old_padding,
dilation=self.dilation,
groups=self.groups,
bias=self.bias is not None,
padding_mode=self.old_padding_mode,
)
layer.weight.data = self.weight.detach()
if self.bias is not None:
layer.bias.data = self.bias.detach()
return layer


class PadConv2d(torch.nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
):
"""
This class is a Conv2d Layer with additional padding modes
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution.
padding (int or tuple, optional): Zero-padding added to both sides of
the input.
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
``'replicate'``,``'symmetric'`` or ``'circular'``.
Default: ``'zeros'``
dilation (int or tuple, optional): Spacing between kernel elements.
Has to be one
groups (int, optional): Number of blocked connections from input
channels to output channels. Has to be one
bias (bool, optional): If ``True``, adds a learnable bias to the
output.
This documentation reuse the body of the original torch.nn.Conv2D doc.
"""

self.old_padding = padding
self.old_padding_mode = padding_mode
if padding_mode.lower() == "symmetric":
padding_mode = "zeros"
padding = "valid"

super(PadConv2d, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
dilation=dilation,
groups=groups,
padding_mode=padding_mode,
)

if self.old_padding_mode.lower() == "symmetric":
self.pad = SymmetricPad(self.old_padding)
else:
self.pad = lambda x: x

def forward(self, input: torch.Tensor) -> torch.Tensor:
return super(PadConv2d, self).forward(self.pad(input))

def vanilla_export(self):
if self.old_padding_mode.lower() == "symmetric":
next_layer_type = PadConv2d
else:
next_layer_type = torch.nn.Conv2d

layer = next_layer_type(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.old_padding,
dilation=self.dilation,
groups=self.groups,
bias=self.bias is not None,
padding_mode=self.old_padding_mode,
)
layer.weight.data = self.weight.detach()
if self.bias is not None:
layer.bias.data = self.bias.detach()
return layer
Loading

0 comments on commit 852baa5

Please sign in to comment.