diff --git a/CHANGELOG.md b/CHANGELOG.md index 82ce08594b3104..67e5711ff2068f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Moved the call to `untoggle_optimizer(opt_idx)` out of the closure function - Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272)) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 598198bbfe7943..d2155bdeef21bb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -726,7 +726,6 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi # ------------------- # calculate loss (train step + train step end) # ------------------- - # automatic_optimization=True: perform ddp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): @@ -739,6 +738,9 @@ def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimi else: if self.trainer.lightning_module.automatic_optimization: self.optimizer_step(optimizer, opt_idx, batch_idx, closure) + if len(self.trainer.optimizers) > 1: + # revert back to previous state + self.trainer.lightning_module.untoggle_optimizer(opt_idx) else: result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens) @@ -839,10 +841,6 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, "training_step returned None. If this was on purpose, ignore this warning..." ) - if len(self.trainer.optimizers) > 1: - # revert back to previous state - self.trainer.lightning_module.untoggle_optimizer(opt_idx) - return result def _check_finite(self, loss: torch.Tensor) -> None: