Skip to content

Commit

Permalink
modify vanilla_model functio to support parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Dec 9, 2024
1 parent c6bf80f commit 16b0059
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deel/torchlip/modules/downsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class InvertibleDownSampling(torch.nn.PixelUnshuffle, LipschitzModule):
def __init__(self, kernel_size: int, k_coef_lip: float = 1.0):
torch.nn.PixelUnshuffle.__init__(self, downscale_factor = kernel_size)
torch.nn.PixelUnshuffle.__init__(self, downscale_factor=kernel_size)
LipschitzModule.__init__(self, k_coef_lip)

def vanilla_export(self):
Expand Down
7 changes: 3 additions & 4 deletions deel/torchlip/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ def vanilla_model(model: nn.Module):
model (nn.Module): Lipschitz neural network
"""
for n, module in model.named_children():
if len(list(module.children())) > 0:
# compound module, go inside it
vanilla_model(module)

if isinstance(module, LipschitzModule):
# simple module
setattr(model, n, module.vanilla_export())
elif len(list(module.children())) > 0:
# compound module, go inside it
vanilla_model(module)


class _LipschitzCoefMultiplication(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion deel/torchlip/modules/upsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

class InvertibleUpSampling(torch.nn.PixelShuffle, LipschitzModule):
def __init__(self, kernel_size: int, k_coef_lip: float = 1.0):
torch.nn.PixelShuffle.__init__(self, upscale_factor = kernel_size)
torch.nn.PixelShuffle.__init__(self, upscale_factor=kernel_size)
LipschitzModule.__init__(self, k_coef_lip)

def vanilla_export(self):
Expand Down

0 comments on commit 16b0059

Please sign in to comment.