Merge pull request #23 from deel-ai/feat/new_layers
Add several new layers to support ResNet-like architectures
thib-s authored Jan 14, 2025
2 parents 4360edf + 76a6e69 commit b23e090
Showing 39 changed files with 4,058 additions and 1,006 deletions.
71 changes: 59 additions & 12 deletions deel/torchlip/functional.py
@@ -212,30 +212,38 @@ def max_min(input: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
     return torch.cat((F.relu(input), F.relu(-input)), dim=dim)


-def group_sort(input: torch.Tensor, group_size: Optional[int] = None) -> torch.Tensor:
+def group_sort(
+    input: torch.Tensor, group_size: Optional[int] = None, dim: int = 1
+) -> torch.Tensor:
     r"""
     Applies GroupSort activation on the given tensor.

     See Also:
         :py:func:`group_sort_2`
         :py:func:`full_sort`
     """
-    if group_size is None or group_size > input.shape[1]:
-        group_size = input.shape[1]
+    if group_size is None or group_size > input.shape[dim]:
+        group_size = input.shape[dim]

-    if input.shape[1] % group_size != 0:
+    if input.shape[dim] % group_size != 0:
         raise ValueError("The input size must be a multiple of the group size.")

-    fv = input.reshape([-1, group_size])
+    new_shape = (
+        input.shape[:dim]
+        + (input.shape[dim] // group_size, group_size)
+        + input.shape[dim + 1 :]
+    )
     if group_size == 2:
-        sfv = torch.chunk(fv, 2, 1)
-        b = sfv[0]
-        c = sfv[1]
-        newv = torch.cat((torch.min(b, c), torch.max(b, c)), dim=1)
-        newv = newv.reshape(input.shape)
-        return newv
+        resh_input = input.view(new_shape)
+        a, b = (
+            torch.min(resh_input, dim + 1, keepdim=True)[0],
+            torch.max(resh_input, dim + 1, keepdim=True)[0],
+        )
+        return torch.cat([a, b], dim=dim + 1).view(input.shape)
+    fv = input.reshape(new_shape)

-    return torch.sort(fv)[0].reshape(input.shape)
+    return torch.sort(fv, dim=dim + 1)[0].reshape(input.shape)


 def group_sort_2(input: torch.Tensor) -> torch.Tensor:
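Not part of the diff: a minimal usage sketch of the updated `group_sort`, showing the new `dim` argument (tensor shapes are illustrative):

```python
import torch

from deel.torchlip.functional import group_sort

x = torch.randn(8, 6, 32, 32)              # (batch, channels, height, width)
y = group_sort(x, group_size=2, dim=1)     # sort values by pairs along the channel axis
assert y.shape == x.shape                  # the activation is shape-preserving
```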
@@ -568,3 +576,42 @@ def process_labels_for_multi_gpu(labels: torch.Tensor) -> torch.Tensor:
     # Since element-wise KR terms are averaged by loss reduction later on, it is needed
     # to multiply by batch_size here.
     return torch.where(labels > 0, pos_factor, neg_factor)
+
+
+class SymmetricPad(torch.nn.Module):
+    """
+    Pads a 2D tensor symmetrically.
+
+    Args:
+        pad (tuple): A tuple (pad_left, pad_right, pad_top, pad_bottom) specifying
+            the number of pixels to pad on each side. (or single int if
+            common padding).
+        onedim: False for conv2d, True for conv1d.
+    """
+
+    def __init__(self, pad, onedim=False):
+        super().__init__()
+        self.onedim = onedim
+        num_dim = 2 if onedim else 4
+        if isinstance(pad, int):
+            self.pad = (pad,) * num_dim
+        else:
+            self.pad = torch.nn.modules.utils._reverse_repeat_tuple(pad, 2)
+        assert len(self.pad) == num_dim, f"Pad must be a tuple of {num_dim} integers"
+
+    def forward(self, x):
+
+        # Horizontal padding
+        left = x[:, ..., : self.pad[0]].flip(dims=[-1])
+        right = x[:, ..., -self.pad[1] :].flip(dims=[-1])
+        x = torch.cat([left, x, right], dim=-1)
+        if self.onedim:
+            return x
+        # Vertical padding
+        top = x[:, :, : self.pad[2], :].flip(dims=[-2])
+        bottom = x[:, :, -self.pad[3] :, :].flip(dims=[-2])
+        x = torch.cat([top, x, bottom], dim=-2)
+
+        return x
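Not part of the diff: a quick sanity check of the new `SymmetricPad` helper (shapes are illustrative):

```python
import torch

from deel.torchlip.functional import SymmetricPad

pad2d = SymmetricPad(1)                         # pad 1 pixel on all four sides of a 2D map
x = torch.arange(16.0).reshape(1, 1, 4, 4)      # (batch, channels, height, width)
assert pad2d(x).shape == (1, 1, 6, 6)           # each spatial dimension grows by 2 * pad

pad1d = SymmetricPad(2, onedim=True)            # 1D variant: only the last dimension is padded
assert pad1d(torch.randn(1, 3, 10)).shape == (1, 3, 14)
```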
9 changes: 9 additions & 0 deletions deel/torchlip/modules/__init__.py
@@ -48,10 +48,13 @@
 from .activation import FullSort
 from .activation import GroupSort
 from .activation import GroupSort2
+from .activation import HouseHolder
 from .activation import LPReLU
 from .activation import MaxMin
 from .conv import FrobeniusConv2d
 from .conv import SpectralConv2d
+from .conv import SpectralConv1d
+from .conv import SpectralConvTranspose2d
 from .downsampling import InvertibleDownSampling
 from .linear import FrobeniusLinear
 from .linear import SpectralLinear
@@ -72,4 +75,10 @@
 from .pooling import ScaledAdaptiveAvgPool2d
 from .pooling import ScaledAvgPool2d
 from .pooling import ScaledL2NormPool2d
+from .pooling import ScaledAdaptativeL2NormPool2d
 from .upsampling import InvertibleUpSampling
+from .normalization import LayerCentering
+from .normalization import BatchCentering
+from .unconstrained import PadConv2d
+from .unconstrained import PadConv1d
+from .residual import LipResidual
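Not part of the diff: the symbols added above become importable from `deel.torchlip.modules` (whether they are also re-exported by the top-level `deel.torchlip` package depends on its own `__init__`, which is not shown in this commit):

```python
# Import a few of the layers newly exposed by this commit.
from deel.torchlip.modules import (
    HouseHolder,
    LipResidual,
    PadConv1d,
    PadConv2d,
    SpectralConv1d,
    SpectralConvTranspose2d,
)
```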
53 changes: 53 additions & 0 deletions deel/torchlip/modules/activation.py
@@ -33,6 +33,7 @@

 import torch
 import torch.nn as nn
+import numpy as np

 from .. import functional as F
 from .module import LipschitzModule
@@ -211,3 +212,55 @@ def vanilla_export(self):
         layer = LPReLU(num_parameters=self.num_parameters)
         layer.weight.data = self.weight.data
         return layer
+
+
+class HouseHolder(nn.Module, LipschitzModule):
+    def __init__(self, channels, k_coef_lip: float = 1.0, theta_initializer=None):
+        """
+        Householder activation:
+        [this review](https://openreview.net/pdf?id=tD7eCtaSkR)
+        Adapted from [this repository](https://github.com/singlasahil14/SOC)
+        """
+        nn.Module.__init__(self)
+        LipschitzModule.__init__(self, k_coef_lip)
+        assert (channels % 2) == 0
+        eff_channels = channels // 2
+
+        if isinstance(theta_initializer, float):
+            coef_theta = theta_initializer
+        else:
+            coef_theta = 0.5 * np.pi
+        self.theta = nn.Parameter(
+            coef_theta * torch.ones(eff_channels), requires_grad=True
+        )
+        if theta_initializer is not None:
+            if isinstance(theta_initializer, str):
+                name2init = {
+                    "zeros": torch.nn.init.zeros_,
+                    "ones": torch.nn.init.ones_,
+                    "normal": torch.nn.init.normal_,
+                }
+                assert (
+                    theta_initializer in name2init
+                ), f"Unknown initializer {theta_initializer}"
+                name2init[theta_initializer](self.theta)
+            elif isinstance(theta_initializer, float):
+                pass
+            else:
+                raise ValueError(f"Unknown initializer {theta_initializer}")
+
+    def forward(self, z, axis=1):
+        theta_shape = (1, -1) + (1,) * (len(z.shape) - 2)
+        theta = self.theta.view(theta_shape)
+        x, y = z.split(z.shape[axis] // 2, axis)
+        selector = (x * torch.sin(0.5 * theta)) - (y * torch.cos(0.5 * theta))
+
+        a_2 = x * torch.cos(theta) + y * torch.sin(theta)
+        b_2 = x * torch.sin(theta) - y * torch.cos(theta)
+
+        a = x * (selector <= 0) + a_2 * (selector > 0)
+        b = y * (selector <= 0) + b_2 * (selector > 0)
+        return torch.cat([a, b], dim=axis)
+
+    def vanilla_export(self):
+        return self
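Not part of the diff: a small usage sketch of the new `HouseHolder` activation (shapes are illustrative; the named initializer is one of the strings accepted above):

```python
import torch

from deel.torchlip.modules import HouseHolder

act = HouseHolder(channels=16)                 # channels must be even: pairs of channels are rotated
x = torch.randn(4, 16, 8, 8)                   # (batch, channels, height, width)
assert act(x).shape == x.shape                 # the activation is shape-preserving

act_zero = HouseHolder(channels=16, theta_initializer="zeros")   # "zeros", "ones" or "normal"
```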