Skip to content

Commit

Permalink
feat: add in-place vanilla model conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
y-prudent committed Nov 30, 2023
1 parent 80a75a0 commit 5ab989b
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions deel/torchlip/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@
logger = logging.getLogger("deel.torchlip")


def vanilla_model(model: nn.Module):
"""Convert lipschitz modules into their non-lipschitz counterpart (for
instance, SpectralConv2d layers become Conv2d layers).
Warning: This function modifies the model in-place.
Args:
model (nn.Module): Lipschitz neural network
"""
for n, module in model.named_children():
if len(list(module.children())) > 0:
# compound module, go inside it
vanilla_model(module)

if isinstance(module, LipschitzModule):
# simple module
setattr(model, n, module.vanilla_export())


class _LipschitzCoefMultiplication(nn.Module):
"""Parametrization module for lipschitz global coefficient multiplication."""

Expand Down

0 comments on commit 5ab989b

Please sign in to comment.