diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 7931c896349de..945732a45c31a 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -212,8 +212,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
                     "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
                 )
 
-            # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
-            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
+            # We assert that there is only one optimizer on fit start
+            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
             assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
 
             if trainer.lr_scheduler_configs:
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 71ed41f52deac..2d18e043fe85e 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -1372,6 +1372,50 @@ def backward(self, loss):
         else:
             loss.backward(*args, **kwargs)
 
+    def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
+        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
+        to prevent dangling gradients in a multi-optimizer setup.
+
+        It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.
+
+        Args:
+            optimizer: The optimizer to toggle.
+        """
+        # Iterate over all optimizer parameters to preserve their `requires_grad` information
+        # in case these are pre-defined during `configure_optimizers`
+        param_requires_grad_state = {}
+        for opt in self.trainer.optimizers:
+            for group in opt.param_groups:
+                for param in group["params"]:
+                    # If a param already appears in param_requires_grad_state, continue
+                    if param in param_requires_grad_state:
+                        continue
+                    param_requires_grad_state[param] = param.requires_grad
+                    param.requires_grad = False
+
+        # Then iterate over the current optimizer's parameters and set their `requires_grad`
+        # attributes accordingly
+        for group in optimizer.param_groups:  # type: ignore[union-attr]
+            for param in group["params"]:
+                param.requires_grad = param_requires_grad_state[param]
+        self._param_requires_grad_state = param_requires_grad_state
+
+    def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
+        """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`.
+
+        Args:
+            optimizer: The optimizer to untoggle.
+        """
+        for opt in self.trainer.optimizers:
+            # TODO: handle comparison when `optimizer` is a LightningOptimizer
+            if opt != optimizer:
+                for group in opt.param_groups:
+                    for param in group["params"]:
+                        if param in self._param_requires_grad_state:
+                            param.requires_grad = self._param_requires_grad_state[param]
+        # save memory
+        self._param_requires_grad_state = {}
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
index 53d1cbf38d870..6c70f61462a9f 100644
--- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
+++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
@@ -932,7 +932,7 @@ def training_step(self, batch, batch_idx):
             # Discriminator.
             optimizer_idx = 0
             optimizer = self.optimizers()[optimizer_idx]
-            self.toggle_optimizer(optimizer, optimizer_idx)
+            self.toggle_optimizer(optimizer)
 
             loss_d = self.step(batch)
             self.log("loss_d", loss_d, prog_bar=True)
@@ -940,12 +940,12 @@ def training_step(self, batch, batch_idx):
             optimizer.zero_grad()
             self.manual_backward(loss_d)
             optimizer.step()
-            self.untoggle_optimizer(optimizer_idx)
+            self.untoggle_optimizer(optimizer)
 
             # Generator.
             optimizer_idx = 1
             optimizer = self.optimizers()[optimizer_idx]
-            self.toggle_optimizer(optimizer, optimizer_idx)
+            self.toggle_optimizer(optimizer)
 
             loss_g = self.step(batch)
             self.log("loss_g", loss_g, prog_bar=True)
@@ -953,7 +953,7 @@ def training_step(self, batch, batch_idx):
             optimizer.zero_grad()
             self.manual_backward(loss_g)
             optimizer.step()
-            self.untoggle_optimizer(optimizer_idx)
+            self.untoggle_optimizer(optimizer)
 
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
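
# For reviewers, a minimal sketch (not part of the patch) of how the new index-free API is
# called from a manual-optimization training_step, mirroring the test changes above: pass the
# optimizer object itself to toggle_optimizer/untoggle_optimizer instead of an optimizer_idx.
# The `ToyGAN` module, its `gen`/`disc` layers, the toy losses, and the learning rates are
# hypothetical names used only for illustration.

import torch
import pytorch_lightning as pl


class ToyGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization: we call backward/step ourselves
        self.gen = torch.nn.Linear(32, 32)
        self.disc = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        # With multiple optimizers, self.optimizers() returns them in configure_optimizers order.
        opt_g, opt_d = self.optimizers()

        # Discriminator step: freeze every other optimizer's params while this loss is computed.
        self.toggle_optimizer(opt_d)  # was: self.toggle_optimizer(opt_d, optimizer_idx)
        loss_d = self.disc(batch).mean()  # placeholder loss for the sketch
        opt_d.zero_grad()
        self.manual_backward(loss_d)
        opt_d.step()
        self.untoggle_optimizer(opt_d)  # was: self.untoggle_optimizer(optimizer_idx)

        # Generator step: gradients flow through the (frozen) discriminator into the generator only.
        self.toggle_optimizer(opt_g)
        loss_g = self.disc(self.gen(batch)).mean()  # placeholder loss for the sketch
        opt_g.zero_grad()
        self.manual_backward(loss_g)
        opt_g.step()
        self.untoggle_optimizer(opt_g)

    def configure_optimizers(self):
        opt_g = torch.optim.SGD(self.gen.parameters(), lr=0.1)
        opt_d = torch.optim.SGD(self.disc.parameters(), lr=0.1)
        return opt_g, opt_d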