Commit

tests

awaelchli committed Jan 29, 2023
1 parent 4fdba24 commit 4420593
Showing 3 changed files with 50 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -212,8 +212,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
"This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
)

# We assert that there is only one optimizer on fit start, so know opt_idx is always 0
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
# We assert that there is only one optimizer on fit start
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

if trainer.lr_scheduler_configs:
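For context, a minimal sketch of the new call site, assuming LRSchedulerConfig is the dataclass importable from pytorch_lightning.utilities.types and that its defaults are interval="epoch" and frequency=1 (the assert above relies on them); the optimizer and scheduler below are invented for illustration:

import torch
from torch.optim import SGD
from torch.optim.swa_utils import SWALR

from pytorch_lightning.utilities.types import LRSchedulerConfig  # assumed import path

params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameters
optimizer = SGD(params, lr=0.1)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

# No opt_idx anymore: the callback asserts a single optimizer at fit start instead.
default_scheduler_cfg = LRSchedulerConfig(swa_scheduler)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1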
44 changes: 44 additions & 0 deletions src/pytorch_lightning/core/module.py
@@ -1372,6 +1372,50 @@ def backward(self, loss):
    else:
        loss.backward(*args, **kwargs)

def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
    """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
    to prevent dangling gradients in a multiple-optimizer setup.
    It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.
    Args:
        optimizer: The optimizer to toggle.
    """
    # Iterate over all optimizer parameters to preserve their `requires_grad` information
    # in case these are pre-defined during `configure_optimizers`
    param_requires_grad_state = {}
    for opt in self.trainer.optimizers:
        for group in opt.param_groups:
            for param in group["params"]:
                # If a param already appears in param_requires_grad_state, continue
                if param in param_requires_grad_state:
                    continue
                param_requires_grad_state[param] = param.requires_grad
                param.requires_grad = False

    # Then iterate over the current optimizer's parameters and set its `requires_grad`
    # properties accordingly
    for group in optimizer.param_groups:  # type: ignore[union-attr]
        for param in group["params"]:
            param.requires_grad = param_requires_grad_state[param]
    self._param_requires_grad_state = param_requires_grad_state

def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
    """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`.
    Args:
        optimizer: The optimizer to untoggle.
    """
    for opt in self.trainer.optimizers:
        # TODO: handle comparison when LightningOptimizer
        if opt != optimizer:
            for group in opt.param_groups:
                for param in group["params"]:
                    if param in self._param_requires_grad_state:
                        param.requires_grad = self._param_requires_grad_state[param]
    # save memory
    self._param_requires_grad_state = {}

def clip_gradients(
    self,
    optimizer: Optimizer,
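As a reading aid, here is a standalone sketch (plain PyTorch, no Trainer) of the bookkeeping the two new methods perform: save every parameter's requires_grad, freeze everything, re-enable only the toggled optimizer's parameters, and restore the rest on untoggle. The modules, optimizers, and names below are invented for illustration:

import torch

gen = torch.nn.Linear(2, 2)
disc = torch.nn.Linear(2, 2)
opt_g = torch.optim.SGD(gen.parameters(), lr=0.1)
opt_d = torch.optim.SGD(disc.parameters(), lr=0.1)

# "toggle" opt_g: remember every parameter's requires_grad and switch it off ...
saved = {p: p.requires_grad for opt in (opt_g, opt_d) for group in opt.param_groups for p in group["params"]}
for p in saved:
    p.requires_grad = False
# ... then restore it only for opt_g's own parameters
for group in opt_g.param_groups:
    for p in group["params"]:
        p.requires_grad = saved[p]

assert all(p.requires_grad for p in gen.parameters())
assert not any(p.requires_grad for p in disc.parameters())

# "untoggle" opt_g: restore the saved state for the other optimizer's parameters
for group in opt_d.param_groups:
    for p in group["params"]:
        p.requires_grad = saved[p]
assert all(p.requires_grad for p in disc.parameters())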
@@ -932,28 +932,28 @@ def training_step(self, batch, batch_idx):
# Discriminator.
optimizer_idx = 0
optimizer = self.optimizers()[optimizer_idx]
self.toggle_optimizer(optimizer, optimizer_idx)
self.toggle_optimizer(optimizer)

loss_d = self.step(batch)
self.log("loss_d", loss_d, prog_bar=True)

optimizer.zero_grad()
self.manual_backward(loss_d)
optimizer.step()
self.untoggle_optimizer(optimizer_idx)
self.untoggle_optimizer(optimizer)

# Generator.
optimizer_idx = 1
optimizer = self.optimizers()[optimizer_idx]
self.toggle_optimizer(optimizer, optimizer_idx)
self.toggle_optimizer(optimizer)

loss_g = self.step(batch)
self.log("loss_g", loss_g, prog_bar=True)

optimizer.zero_grad()
self.manual_backward(loss_g)
optimizer.step()
self.untoggle_optimizer(optimizer_idx)
self.untoggle_optimizer(optimizer)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
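The training_step above indexes two optimizers, so the module's configure_optimizers (only partially visible here before the diff is cut off) must return two of them. A hypothetical sketch of such a method; self.layer2 is an assumed second submodule, not something shown in this diff:

import torch

# Hypothetical two-optimizer setup matching the indices used above:
# index 0 -> discriminator, index 1 -> generator. `self.layer2` is assumed.
def configure_optimizers(self):
    optimizer_d = torch.optim.SGD(self.layer.parameters(), lr=0.1)
    optimizer_g = torch.optim.SGD(self.layer2.parameters(), lr=0.1)
    return optimizer_d, optimizer_g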
