From 6a5658649208f6e0caa214326389d6f1e7cc02f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Feb 2023 17:21:01 +0100 Subject: [PATCH] Make manual optimization mandatory for multiple optimizers (#16539) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- .../computer_vision_fine_tuning.py | 2 +- src/pytorch_lightning/CHANGELOG.md | 20 ++ src/pytorch_lightning/callbacks/callback.py | 2 +- src/pytorch_lightning/callbacks/finetuning.py | 19 +- .../callbacks/stochastic_weight_avg.py | 4 +- src/pytorch_lightning/core/hooks.py | 5 +- src/pytorch_lightning/core/module.py | 210 ++++-------------- src/pytorch_lightning/core/optimizer.py | 50 ++--- .../loops/epoch/training_epoch_loop.py | 181 +-------------- src/pytorch_lightning/loops/fit_loop.py | 20 +- .../loops/optimization/manual_loop.py | 14 -- .../loops/optimization/optimizer_loop.py | 146 ++---------- src/pytorch_lightning/loops/progress.py | 9 - src/pytorch_lightning/loops/utilities.py | 71 +----- .../plugins/precision/amp.py | 11 +- .../plugins/precision/colossalai.py | 4 +- .../plugins/precision/deepspeed.py | 9 +- .../plugins/precision/ipu.py | 7 +- .../plugins/precision/precision_plugin.py | 23 +- .../plugins/precision/tpu.py | 3 +- .../strategies/colossalai.py | 5 +- src/pytorch_lightning/strategies/ddp.py | 4 +- src/pytorch_lightning/strategies/deepspeed.py | 20 +- .../strategies/hpu_parallel.py | 3 +- .../strategies/single_hpu.py | 3 +- src/pytorch_lightning/strategies/strategy.py | 17 +- src/pytorch_lightning/trainer/trainer.py | 8 - src/pytorch_lightning/tuner/lr_finder.py | 9 +- .../utilities/migration/migration.py | 28 ++- .../utilities/signature_utils.py | 2 +- src/pytorch_lightning/utilities/types.py | 2 - tests/tests_pytorch/accelerators/test_ipu.py | 3 + .../callbacks/test_finetuning_callback.py | 6 +- .../callbacks/test_lr_monitor.py | 2 +- .../core/test_lightning_module.py | 6 +- .../core/test_lightning_optimizer.py | 6 +- .../helpers/deterministic_model.py | 4 +- .../loops/epoch/test_training_epoch_loop.py | 113 ---------- .../loops/optimization/test_closure.py | 4 +- .../loops/optimization/test_optimizer_loop.py | 170 +------------- .../loops/test_evaluation_loop_flow.py | 24 +- .../loops/test_loop_state_dict.py | 1 - tests/tests_pytorch/loops/test_loops.py | 93 ++------ .../loops/test_training_loop_flow_dict.py | 16 +- .../loops/test_training_loop_flow_scalar.py | 28 +-- tests/tests_pytorch/models/test_hooks.py | 20 +- .../plugins/precision/test_tpu.py | 2 +- .../tests_pytorch/plugins/test_amp_plugins.py | 2 +- .../strategies/test_colossalai.py | 3 +- .../strategies/test_deepspeed_strategy.py | 9 +- .../test_multiple_eval_dataloaders.py | 59 ----- .../optimization/test_manual_optimization.py | 16 +- .../optimization/test_multiple_optimizers.py | 135 +---------- .../trainer/optimization/test_optimizers.py | 187 ++-------------- tests/tests_pytorch/tuner/test_lr_finder.py | 4 +- .../utilities/migration/test_migration.py | 32 ++- 56 files changed, 344 insertions(+), 1512 deletions(-) diff --git a/examples/pl_domain_templates/computer_vision_fine_tuning.py b/examples/pl_domain_templates/computer_vision_fine_tuning.py index afcfa8f90066b..53790d1f81a4e 100644 --- a/examples/pl_domain_templates/computer_vision_fine_tuning.py +++ b/examples/pl_domain_templates/computer_vision_fine_tuning.py @@ -75,7 +75,7 @@ def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False): def freeze_before_training(self, pl_module: 
LightningModule): self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn) - def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer): if epoch == self.milestones[0]: # unfreeze 5 last layers self.unfreeze_and_add_param_group( diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 3c19e707f09d1..115d47f26628d 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `LightningModule.{un}toggle_optimizer` methods no longer accept a `optimizer_idx` argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly ([#16560](https://github.com/Lightning-AI/lightning/pull/16560)) +- Manual optimization is now required for working with multiple optimizers ([#16539](https://github.com/Lightning-AI/lightning/pull/16539)) + + ### Deprecated - @@ -173,6 +176,23 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the `using_lbfgs` argument from `LightningModule.optimizer_step` hook ([#16538](https://github.com/Lightning-AI/lightning/pull/16538)) +- Removed support for multiple optimizers in automatic optimization mode ([#16539](https://github.com/Lightning-AI/lightning/pull/16539)) + * Removed `opt_idx` argument from `BaseFinetuning.finetune_function` callback method + * Removed `opt_idx` argument from `Callback.on_before_optimizer_step` callback method + * Removed `optimizer_idx` as an optional argument in `LightningModule.training_step` + * Removed `optimizer_idx` argument from `LightningModule.on_before_optimizer_step` + * Removed `optimizer_idx` argument from `LightningModule.configure_gradient_clipping` + * Removed `optimizer_idx` argument from `LightningModule.optimizer_step` + * Removed `optimizer_idx` argument from `LightningModule.optimizer_zero_grad` + * Removed `optimizer_idx` argument from `LightningModule.lr_scheduler_step` + * Removed support for declaring optimizer frequencies in the dictionary returned from `LightningModule.configure_optimizers` + * Removed arguments `optimizer` and `optimizer_idx` from `LightningModule.backward` + * Removed `optimizer_idx` argument from `PrecisionPlugin.optimizer_step` and all of its overrides in subclasses + * Removed `optimizer_idx` argument from `PrecisionPlugin.{optimizer_step,backward}` and all of its overrides in subclasses + * Removed `optimizer_idx` argument from `Strategy.{optimizer_step,backward}` and all of its overrides in subclasses + * Removed `Trainer.optimizer_frequencies` attribute + + ### Fixed - Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369)) diff --git a/src/pytorch_lightning/callbacks/callback.py b/src/pytorch_lightning/callbacks/callback.py index d8cfdb5399ca6..dda36f12809e9 100644 --- a/src/pytorch_lightning/callbacks/callback.py +++ b/src/pytorch_lightning/callbacks/callback.py @@ -241,7 +241,7 @@ def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul """Called after ``loss.backward()`` and before optimizers are stepped.""" def on_before_optimizer_step( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int 
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer ) -> None: """Called before ``optimizer.step()``.""" diff --git a/src/pytorch_lightning/callbacks/finetuning.py b/src/pytorch_lightning/callbacks/finetuning.py index f115c33bb2f82..72f70c3b30456 100644 --- a/src/pytorch_lightning/callbacks/finetuning.py +++ b/src/pytorch_lightning/callbacks/finetuning.py @@ -46,7 +46,7 @@ class BaseFinetuning(Callback): and should be used to freeze any modules parameters. ``finetune_function``: This method is called on every train epoch start and should be used to - ``unfreeze`` any parameters. Those parameters needs to be added in a new ``param_group`` + ``unfreeze`` any parameters. Those parameters need to be added in a new ``param_group`` within the optimizer. .. note:: Make sure to filter the parameters based on ``requires_grad``. @@ -69,7 +69,7 @@ class BaseFinetuning(Callback): ... # Here, we are freezing `feature_extractor` ... self.freeze(pl_module.feature_extractor) ... - ... def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): + ... def finetune_function(self, pl_module, current_epoch, optimizer): ... # When `current_epoch` is 10, feature_extractor will start training. ... if current_epoch == self._unfreeze_at_epoch: ... self.unfreeze_and_add_param_group( @@ -290,18 +290,13 @@ def _store( def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when the epoch begins.""" - # import is here to avoid circular imports - from pytorch_lightning.loops.utilities import _get_active_optimizers - - for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0): + for opt_idx, optimizer in enumerate(trainer.optimizers): num_param_groups = len(optimizer.param_groups) - self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) + self.finetune_function(pl_module, trainer.current_epoch, optimizer) current_param_groups = optimizer.param_groups self._store(pl_module, opt_idx, num_param_groups, current_param_groups) - def finetune_function( - self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int - ) -> None: + def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None: """Override to add your unfreeze logic.""" raise NotImplementedError @@ -389,9 +384,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - def freeze_before_training(self, pl_module: "pl.LightningModule") -> None: self.freeze(pl_module.backbone) - def finetune_function( - self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int - ) -> None: + def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None: """Called when the epoch begins.""" if epoch == self.unfreeze_backbone_at_epoch: current_lr = optimizer.param_groups[0]["lr"] diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 7931c896349de..945732a45c31a 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -212,8 +212,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo "This may be caused by loading a checkpoint from an older version of PyTorch Lightning." 
) - # We assert that there is only one optimizer on fit start, so know opt_idx is always 0 - default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) + # We assert that there is only one optimizer on fit start + default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 if trainer.lr_scheduler_configs: diff --git a/src/pytorch_lightning/core/hooks.py b/src/pytorch_lightning/core/hooks.py index 025e9bb74c5ca..acf59ed9414df 100644 --- a/src/pytorch_lightning/core/hooks.py +++ b/src/pytorch_lightning/core/hooks.py @@ -229,7 +229,7 @@ def on_after_backward(self) -> None: Use the ``on_before_optimizer_step`` if you need the unscaled gradients. """ - def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: + def on_before_optimizer_step(self, optimizer: Optimizer) -> None: """Called before ``optimizer.step()``. If using gradient accumulation, the hook is called once the gradients have been accumulated. @@ -243,11 +243,10 @@ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> Args: optimizer: Current optimizer being used. - optimizer_idx: Index of the current optimizer being used. Example:: - def on_before_optimizer_step(self, optimizer, optimizer_idx): + def on_before_optimizer_step(self, optimizer): # example to inspect gradient information in tensorboard if self.trainer.global_step % 25 == 0: # don't make the tf file huge for k, v in self.named_parameters(): diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 8c957960d9ef6..549166c2f2e05 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -35,7 +35,6 @@ from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning_fabric.utilities.distributed import _distributed_available, _sync_ddp from lightning_fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_2_0 -from lightning_fabric.utilities.types import Steppable from lightning_fabric.wrappers import _FabricOptimizer from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks @@ -660,7 +659,6 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: # type: igno batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (``int``): Integer displaying index of this batch - optimizer_idx (``int``): When using multiple optimizers, this argument will also be present. Return: Any of. @@ -681,23 +679,25 @@ def training_step(self, batch, batch_idx): loss = self.loss(out, x) return loss - If you define multiple optimizers, this step will be called with an additional - ``optimizer_idx`` parameter. + To use multiple optimizers, you can switch to 'manual optimization' and control their stepping: .. code-block:: python + def __init__(self): + super().__init__() + self.automatic_optimization = False + + # Multiple optimizers (e.g.: GANs) - def training_step(self, batch, batch_idx, optimizer_idx): - if optimizer_idx == 0: - # do training_step with encoder - ... - if optimizer_idx == 1: - # do training_step with decoder - ... 
+ def training_step(self, batch, batch_idx): + opt1, opt2 = self.optimizers() - Note: - The loss value shown in the progress bar is smoothed (averaged) over the last values, - so it differs from the actual loss returned in train/validation step. + # do training_step with encoder + ... + opt1.step() + # do training_step with decoder + ... + opt2.step() Note: When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically @@ -1228,6 +1228,7 @@ def configure_optimizers(self) -> Any: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. + Optimization with multiple optimizers only works in the manual optimization mode. Return: Any of these 6 options. @@ -1238,7 +1239,6 @@ def configure_optimizers(self) -> Any: (or multiple ``lr_scheduler_config``). - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"`` key whose value is a single LR scheduler or ``lr_scheduler_config``. - - **Tuple of dictionaries** as described above, with an optional ``"frequency"`` key. - **None** - Fit will run without any optimizer. The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration. @@ -1311,90 +1311,18 @@ def configure_optimizers(self): Metrics can be made available to monitor by simply logging it using ``self.log('metric_to_track', metric_val)`` in your :class:`~pytorch_lightning.core.module.LightningModule`. - Note: - The ``frequency`` value specified in a dict along with the ``optimizer`` key is an int corresponding - to the number of sequential batches optimized with the specific optimizer. - It should be given to none or to all of the optimizers. - There is a difference between passing multiple optimizers in a list, - and passing multiple optimizers in dictionaries with a frequency of 1: - - - In the former case, all optimizers will operate on the given batch in each optimization step. - - In the latter, only one optimizer will operate on the given batch at every step. - - This is different from the ``frequency`` value specified in the ``lr_scheduler_config`` mentioned above. - - .. code-block:: python - - def configure_optimizers(self): - optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01) - optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01) - return [ - {"optimizer": optimizer_one, "frequency": 5}, - {"optimizer": optimizer_two, "frequency": 10}, - ] - - In this example, the first optimizer will be used for the first 5 steps, - the second optimizer for the next 10 steps and that cycle will continue. - If an LR scheduler is specified for an optimizer using the ``lr_scheduler`` key in the above dict, - the scheduler will only be updated when its optimizer is being used. - - Examples:: - - # most cases. 
no learning rate scheduler - def configure_optimizers(self): - return Adam(self.parameters(), lr=1e-3) - - # multiple optimizer case (e.g.: GAN) - def configure_optimizers(self): - gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_dis.parameters(), lr=0.02) - return gen_opt, dis_opt - - # example with learning rate schedulers - def configure_optimizers(self): - gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_dis.parameters(), lr=0.02) - dis_sch = CosineAnnealing(dis_opt, T_max=10) - return [gen_opt, dis_opt], [dis_sch] - - # example with step-based learning rate schedulers - # each optimizer has its own scheduler - def configure_optimizers(self): - gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_dis.parameters(), lr=0.02) - gen_sch = { - 'scheduler': ExponentialLR(gen_opt, 0.99), - 'interval': 'step' # called after each training step - } - dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch - return [gen_opt, dis_opt], [gen_sch, dis_sch] - - # example with optimizer frequencies - # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1 - # https://arxiv.org/abs/1704.00028 - def configure_optimizers(self): - gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_dis.parameters(), lr=0.02) - n_critic = 5 - return ( - {'optimizer': dis_opt, 'frequency': n_critic}, - {'optimizer': gen_opt, 'frequency': 1} - ) - Note: Some things to know: - - Lightning calls ``.backward()`` and ``.step()`` on each optimizer as needed. - - If learning rate scheduler is specified in ``configure_optimizers()`` with key + - Lightning calls ``.backward()`` and ``.step()`` automatically in case of automatic optimization. + - If a learning rate scheduler is specified in ``configure_optimizers()`` with key ``"interval"`` (default "epoch") in the scheduler configuration, Lightning will call the scheduler's ``.step()`` method automatically in case of automatic optimization. - - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizers. - - If you use multiple optimizers, :meth:`training_step` will have an additional ``optimizer_idx`` parameter. + - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizer. - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you. - - If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer - at each training step. - - If you need to control how often those optimizers step or override the default ``.step()`` schedule, - override the :meth:`optimizer_step` hook. + - If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them + yourself. + - If you need to control how often the optimizer steps, override the :meth:`optimizer_step` hook. 
""" rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer") @@ -1423,23 +1351,19 @@ def training_step(...): self._fabric.backward(loss, *args, **kwargs) else: self._verify_is_manual_optimization("manual_backward") - self.trainer.strategy.backward(loss, None, None, *args, **kwargs) + self.trainer.strategy.backward(loss, None, *args, **kwargs) - def backward( - self, loss: Tensor, optimizer: Optional[Steppable], optimizer_idx: Optional[int], *args: Any, **kwargs: Any - ) -> None: + def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None: """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own implementation if you need to. Args: loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps). - optimizer: Current optimizer being used. ``None`` if using manual optimization. - optimizer_idx: Index of the current optimizer being used. ``None`` if using manual optimization. Example:: - def backward(self, loss, optimizer, optimizer_idx): + def backward(self, loss): loss.backward() """ if self._fabric: @@ -1550,7 +1474,6 @@ def clip_gradients( def configure_gradient_clipping( self, optimizer: Optimizer, - optimizer_idx: int, gradient_clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[str] = None, ) -> None: @@ -1558,36 +1481,27 @@ def configure_gradient_clipping( Args: optimizer: Current optimizer being used. - optimizer_idx: Index of the current optimizer being used. - gradient_clip_val: The value at which to clip gradients. By default value passed in Trainer + gradient_clip_val: The value at which to clip gradients. By default, value passed in Trainer will be available here. - gradient_clip_algorithm: The gradient clipping algorithm to use. By default value + gradient_clip_algorithm: The gradient clipping algorithm to use. By default, value passed in Trainer will be available here. Example:: - # Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN - def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): - if optimizer_idx == 1: - # Lightning will handle the gradient clipping - self.clip_gradients( - optimizer, - gradient_clip_val=gradient_clip_val, - gradient_clip_algorithm=gradient_clip_algorithm - ) - else: - # implement your own custom logic to clip gradients for generator (optimizer_idx=0) + def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): + # Implement your own custom logic to clip gradients + # You can call `self.clip_gradients` with your settings: + self.clip_gradients( + optimizer, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm + ) """ self.clip_gradients( optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) - def lr_scheduler_step( - self, - scheduler: LRSchedulerTypeUnion, - optimizer_idx: int, - metric: Optional[Any], - ) -> None: + def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None: r""" Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler. @@ -1596,20 +1510,19 @@ def lr_scheduler_step( Args: scheduler: Learning rate scheduler. - optimizer_idx: Index of the optimizer associated with this scheduler. 
metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``. Examples:: # DEFAULT - def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + def lr_scheduler_step(self, scheduler, metric): if metric is None: scheduler.step() else: scheduler.step(metric) # Alternative way to update schedulers if it requires an epoch value - def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + def lr_scheduler_step(self, scheduler, metric): scheduler.step(epoch=self.current_epoch) """ @@ -1623,14 +1536,13 @@ def optimizer_step( epoch: int, batch_idx: int, optimizer: Union[Optimizer, LightningOptimizer], - optimizer_idx: int = 0, optimizer_closure: Optional[Callable[[], Any]] = None, ) -> None: r""" Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls - each optimizer. + the optimizer. - By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example once per optimizer. + By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example. This method (and ``zero_grad()``) won't be called during the accumulation phase when ``Trainer(accumulate_grad_batches != 1)``. Overriding this hook has no benefit with manual optimization. @@ -1638,47 +1550,17 @@ def optimizer_step( epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer - optimizer_idx: If you used multiple optimizers, this indexes into that list. optimizer_closure: The optimizer closure. This closure must be executed as it includes the calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``. Examples:: # DEFAULT - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure): + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): optimizer.step(closure=optimizer_closure) - # Alternating schedule for optimizer steps (i.e.: GANs) - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure): - # update generator opt every step - if optimizer_idx == 0: - optimizer.step(closure=optimizer_closure) - - # update discriminator opt every 2 steps - if optimizer_idx == 1: - if (batch_idx + 1) % 2 == 0 : - optimizer.step(closure=optimizer_closure) - else: - # call the closure by itself to run `training_step` + `backward` without an optimizer step - optimizer_closure() - - # ... - # add as many optimizers as you want - - Here's another example showing how to use this for more advanced things such as - learning rate warm-up: - - .. code-block:: python - - # learning rate warm-up - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - ): + # Learning rate warm-up + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): # update params optimizer.step(closure=optimizer_closure) @@ -1687,27 +1569,25 @@ def optimizer_step( lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0) for pg in optimizer.param_groups: pg["lr"] = lr_scale * self.learning_rate - """ optimizer.step(closure=optimizer_closure) - def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None: + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None: """Override this method to change the default behaviour of ``optimizer.zero_grad()``. Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer - optimizer_idx: If you used multiple optimizers this indexes into that list. 
Examples:: # DEFAULT - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): optimizer.zero_grad() # Set gradients to `None` instead of zero to improve performance (not required on `torch>=2.0.0`). - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): optimizer.zero_grad(set_to_none=True) See :meth:`torch.optim.Optimizer.zero_grad` for the explanation of the above example. diff --git a/src/pytorch_lightning/core/optimizer.py b/src/pytorch_lightning/core/optimizer.py index 0aabc0bce7368..bd1c218304487 100644 --- a/src/pytorch_lightning/core/optimizer.py +++ b/src/pytorch_lightning/core/optimizer.py @@ -157,7 +157,7 @@ def closure_dis(): raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable") assert self._strategy is not None - step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs) self._on_after_step() @@ -166,7 +166,7 @@ def closure_dis(): def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]: +) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) @@ -176,21 +176,22 @@ def _init_optimizers_and_lr_schedulers( ) optim_conf = _MockOptimizer() - optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf) + optimizers, lr_schedulers, monitor = _configure_optimizers(optim_conf) lr_scheduler_configs = ( _configure_schedulers_automatic_opt(lr_schedulers, monitor) if model.automatic_optimization else _configure_schedulers_manual_opt(lr_schedulers) ) - _set_scheduler_opt_idx(optimizers, lr_scheduler_configs) + _validate_multiple_optimizers_support(optimizers, model) + _validate_optimizers_attached(optimizers, lr_scheduler_configs) _validate_scheduler_api(lr_scheduler_configs, model) - return optimizers, lr_scheduler_configs, optimizer_frequencies + return optimizers, lr_scheduler_configs def _configure_optimizers( optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] -) -> Tuple[List, List, List, Optional[str]]: - optimizers, lr_schedulers, optimizer_frequencies = [], [], [] +) -> Tuple[List, List, Optional[str]]: + optimizers, lr_schedulers = [], [] monitor = None # single output, single optimizer @@ -228,12 +229,6 @@ def _configure_optimizers( for opt_idx, opt_dict in enumerate(optim_conf) if "lr_scheduler" in opt_dict ] - optimizer_frequencies = [ - opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None - ] - # assert that if frequencies are present, they are given for all optimizers - if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): - raise ValueError("A frequency must be given to each optimizer.") # single list or tuple, multiple optimizer elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf): optimizers = list(optim_conf) @@ -246,9 +241,8 @@ def _configure_optimizers( " * [`Optimizer`]\n" " * ([`Optimizer`], [`LRScheduler`])\n" ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n' - ' * A list of the previously 
described dict format, with an optional "frequency" key (int)' ) - return optimizers, lr_schedulers, optimizer_frequencies, monitor + return optimizers, lr_schedulers, monitor def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: @@ -315,7 +309,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig if isinstance(scheduler, dict): # interval is not in this list even though the user needs to manually call the scheduler because # the `LearningRateMonitor` callback needs to check its value to know when to log the learning rate - invalid_keys = {"frequency", "reduce_on_plateau", "monitor", "strict"} + invalid_keys = {"reduce_on_plateau", "monitor", "strict"} keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys] if keys_to_warn: @@ -349,27 +343,25 @@ def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model ) -def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: - for config in lr_scheduler_configs: +def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None: + if model.automatic_optimization and len(optimizers) > 1: + raise RuntimeError( + "Training with multiple optimizers is only supported with manual optimization. Set" + " `self.automatic_optimization = False`, then access your optimizers in `training_step` with" + " `opt1, opt2, ... = self.optimizers()`." + ) - for opt_idx, opt in enumerate(optimizers): - if config.scheduler.optimizer is opt: - if config.opt_idx is not None and config.opt_idx != opt_idx: - raise MisconfigurationException( - "`opt_idx` set inside scheduler config does not match with the index" - " of the respective optimizer returned from `configure_optimizers`." - ) - config.opt_idx = opt_idx - break - else: +def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: + for config in lr_scheduler_configs: + if config.scheduler.optimizer not in optimizers: raise MisconfigurationException( "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." ) def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: - valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"} + valid_keys = {"optimizer", "lr_scheduler", "monitor"} extra_keys = optim_conf.keys() - valid_keys if extra_keys: rank_zero_warn( diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index ede1fff787ee7..a334e2672e990 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from collections import defaultdict, OrderedDict -from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload, Tuple, Union +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Union -import numpy as np import torch -from lightning_utilities.core.apply_func import apply_to_collection -import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE from pytorch_lightning.loops.progress import BatchProgress, SchedulerProgress -from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached +from pytorch_lightning.loops.utilities import _is_max_limit_reached from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher @@ -223,12 +220,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: self.batch_progress.increment_started() with self.trainer.profiler.profile("run_training_batch"): - # choose which loop will run the optimization if self.trainer.lightning_module.automatic_optimization: - optimizers = _get_active_optimizers( - self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0) - ) - batch_output = self.optimizer_loop.run(optimizers, kwargs) + # in automatic optimization, there can only be one optimizer + batch_output = self.optimizer_loop.run(self.trainer.optimizers[0], kwargs) else: batch_output = self.manual_loop.run(kwargs) @@ -240,14 +234,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: if self._num_ready_batches_reached(): self.update_lr_schedulers("epoch", update_plateau_schedulers=False) - batch_end_outputs = self._prepare_outputs_training_batch_end( - batch_output, - lightning_module=self.trainer.lightning_module, - num_optimizers=len(self.trainer.optimizers), - ) - - self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx) - self.trainer._call_lightning_module_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx) + self.trainer._call_callback_hooks("on_train_batch_end", batch_output, batch, batch_idx) + self.trainer._call_lightning_module_hook("on_train_batch_end", batch_output, batch, batch_idx) self.trainer._logger_connector.on_batch_end() self.batch_progress.increment_completed() @@ -338,73 +326,13 @@ def _should_accumulate(self) -> bool: strategy_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch return not accumulation_done and strategy_accumulates_on_final_batch - @staticmethod - def _prepare_outputs_training_batch_end( - batch_output: _BATCH_OUTPUTS_TYPE, - lightning_module: "pl.LightningModule", - num_optimizers: int, - ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]: - """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook.""" - if not batch_output: - return [] # type: ignore[return-value] - - # convert optimizer dicts to list - if lightning_module.automatic_optimization: - batch_output = 
apply_to_collection( - batch_output, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers - ) - - array = np.array(batch_output, dtype=object) - # squeeze all single-element dimensions - array = array.squeeze() - array = array.tolist() - array = _recursive_unpad(array) - return array - - @staticmethod - def _prepare_outputs_training_epoch_end( - batch_outputs: _OUTPUTS_TYPE, - lightning_module: "pl.LightningModule", - num_optimizers: int, - ) -> Union[List[List[List[Dict[str, Any]]]], List[List[Dict[str, Any]]], List[Dict[str, Any]]]: - """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook.""" - # `batch_outputs` (plural) is the same as `epoch_end_output` (singular) - if not batch_outputs: - return [] # type: ignore[return-value] - - # convert optimizer dicts to list - if lightning_module.automatic_optimization: - batch_outputs = apply_to_collection( - batch_outputs, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers - ) - - array = _recursive_pad(batch_outputs) - # squeeze all single-element dimensions - array = array.squeeze() - array = array.tolist() - array = _recursive_unpad(array) - # in case we squeezed from 1-element array to a 0-dim array - array = array if isinstance(array, list) else [array] - # remove residual empty lists - array = [item for item in array if not isinstance(item, list) or len(item)] - return array - def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None: """updates the lr schedulers based on the given interval.""" if interval == "step" and self._should_accumulate(): return - active_optimizers = _get_active_optimizers( - self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx - ) - self._update_learning_rates( - interval=interval, - update_plateau_schedulers=update_plateau_schedulers, - opt_indices=[opt_idx for opt_idx, _ in active_optimizers], - ) + self._update_learning_rates(interval=interval, update_plateau_schedulers=update_plateau_schedulers) - def _update_learning_rates( - self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None - ) -> None: + def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool) -> None: """Update learning rates. Args: @@ -413,18 +341,11 @@ def _update_learning_rates( This is used so non-plateau schedulers can be updated before running validation. Checkpoints are commonly saved during validation, however, on-plateau schedulers might monitor a validation metric so they have to be updated separately. - opt_indices: indices of the optimizers to update. 
""" if not self.trainer.lr_scheduler_configs or not self.trainer.lightning_module.automatic_optimization: return - if opt_indices is None: - opt_indices = [] - for config in self.trainer.lr_scheduler_configs: - if config.opt_idx not in opt_indices: - continue - if update_plateau_schedulers ^ config.reduce_on_plateau: continue @@ -460,7 +381,6 @@ def _update_learning_rates( self.trainer._call_lightning_module_hook( "lr_scheduler_step", config.scheduler, - config.opt_idx, monitor_val, ) self.scheduler_progress.increment_completed() @@ -520,87 +440,8 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde """ kwargs["batch"] = batch training_step_fx = getattr(self.trainer.lightning_module, "training_step") - # the `batch_idx` is optional, however, when there's more than 1 argument we cannot differentiate whether the - # user wants the `batch_idx` or another key like `optimizer_idx` as we are not strict about the argument names + # the `batch_idx` is optional, but its name can be anything + # as long as there are two argumetns after 'self', we assume they are the `batch` and `batch_idx` if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2): kwargs["batch_idx"] = batch_idx return kwargs - - -def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]: - """Converts an optimizer dict to a list in which the key of the dict determines the position of the element. - - Example:: - >>> _convert_optim_dict({0: {"loss": 0.0}, 2: {"loss": 0.2}}, num_optimizers=3) - [{'loss': 0.0}, None, {'loss': 0.2}] - """ - return [outs[opt_idx] if opt_idx in outs else None for opt_idx in range(num_optimizers)] - - -@overload -def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]: - ... - - -@overload -def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any: - ... - - -def _recursive_unpad(nested: Union[Any, List[Any]], value: Optional[Any] = None) -> Union[Any, List[Any]]: - """Removes the given pad value from the nested list. Not strictly the reverse operation of - :func:`_recursive_pad` because it removes the padding element everywhere, not just from the end of a list. - - Example:: - >>> _recursive_unpad([[[0, 1, 0]], [2], [0, 0]], value=0) - [[[1]], [2], []] - """ - if not isinstance(nested, list): - return nested - - return [_recursive_unpad(item, value) for item in nested if item != value] - - -def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.ndarray: - """Pads a jagged nested list of lists with the given value such that a proper multi-dimensional array can be - formed with rectangular shape. The padding appends to the incomplete lists. 
- - Example:: - >>> _recursive_pad([[], [1], [2, 3], [4]], fill_value=0) # doctest: +NORMALIZE_WHITESPACE - array([[0, 0], [1, 0], [2, 3], [4, 0]], dtype=object) - """ - # code adapted from stackexchange: - # https://codereview.stackexchange.com/questions/222623/pad-a-ragged-multidimensional-array-to-rectangular-shape - dimensions = _get_max_shape(nested) - result = np.full(dimensions, fill_value, dtype=object) - for index, value in _iterate_nested_array(nested): - result[index] = value - return result - - -def _get_dimensions(array: List[Any], level: int = 0) -> Generator: - yield level, len(array) - if all(isinstance(row, list) for row in array): - for row in array: - yield from _get_dimensions(row, level + 1) - - -def _get_max_shape(array: List[Any]) -> List[int]: - """Calculates the max size in each dimension of a jagged (non-rectangular) nested list of lists. - - Example:: - >>> _get_max_shape([[], [[1], [2]], []]) - [3, 2, 1] - """ - dimensions: DefaultDict[int, int] = defaultdict(int) - for level, length in _get_dimensions(array): - dimensions[level] = max(dimensions[level], length) - return [value for _, value in sorted(dimensions.items())] - - -def _iterate_nested_array(array: List[Any], index: Tuple = ()) -> Generator: - if all(isinstance(item, list) for item in array): - for idx, row in enumerate(array): - yield from _iterate_nested_array(row, (*index, idx)) - else: # final level - yield (*index, slice(len(array))), array diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 572a61b824e58..9fc3ac78a4468 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -43,13 +43,12 @@ class _FitLoop(_Loop): for epoch in range(max_epochs): # TrainingEpochLoop for batch_idx, batch in enumerate(train_dataloader): - # OptimizerLoop - for optimizer_idx, opt in enumerate(optimizers): - loss = lightning_module.training_step(batch, batch_idx, optimizer_idx) - ... + loss = lightning_module.training_step(batch, batch_idx) + ... + # ValidationEpochLoop for batch_idx, batch in enumerate(val_dataloader): - lightning_module.validation_step(batch, batch_idx, optimizer_idx) + lightning_module.validation_step(batch, batch_idx) ... ... ... @@ -283,15 +282,8 @@ def on_advance_end(self) -> None: # get the model and call model.training_epoch_end model = self.trainer.lightning_module if is_overridden("training_epoch_end", model) and self._outputs: - epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end( - self._outputs, - lightning_module=model, - num_optimizers=len(self.trainer.optimizers), - ) - # run lightning module hook training_epoch_end - # refresh the result for custom logging at the epoch level - epoch_end_outputs = self.trainer._call_lightning_module_hook("training_epoch_end", epoch_end_outputs) - if epoch_end_outputs is not None: + return_value = self.trainer._call_lightning_module_hook("training_epoch_end", self._outputs) + if return_value is not None: raise MisconfigurationException( "`training_epoch_end` expects a return of None. " "HINT: remove the return statement in `training_epoch_end`." 
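With the per-optimizer output handling removed from the loops above, a ``LightningModule`` that previously branched on ``optimizer_idx`` inside ``training_step`` is now expected to drive its optimizers itself. A minimal migration sketch under manual optimization; the module, layers, and placeholder losses below are illustrative and not part of this patch:

.. code-block:: python

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule


    class TwoOptimizerModel(LightningModule):
        """Illustrative two-optimizer module; names and losses are placeholders."""

        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(32, 32)
            self.discriminator = nn.Linear(32, 1)
            # multiple optimizers now require manual optimization
            self.automatic_optimization = False

        def configure_optimizers(self):
            opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
            return opt_g, opt_d

        def training_step(self, batch, batch_idx):
            opt_g, opt_d = self.optimizers()

            # generator update (previously the `optimizer_idx == 0` branch)
            g_loss = self.discriminator(self.generator(batch)).mean()
            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()

            # discriminator update (previously the `optimizer_idx == 1` branch)
            d_loss = -self.discriminator(batch).mean()
            opt_d.zero_grad()
            self.manual_backward(d_loss)
            opt_d.step()

The same pattern covers the removed ``frequency`` support: stepping one optimizer only every N batches becomes a plain ``if batch_idx % n == 0`` check inside ``training_step``.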
diff --git a/src/pytorch_lightning/loops/optimization/manual_loop.py b/src/pytorch_lightning/loops/optimization/manual_loop.py index e5c7ec4da9364..0f38873324543 100644 --- a/src/pytorch_lightning/loops/optimization/manual_loop.py +++ b/src/pytorch_lightning/loops/optimization/manual_loop.py @@ -22,7 +22,6 @@ from pytorch_lightning.loops import _Loop from pytorch_lightning.loops.optimization.closure import OutputResult from pytorch_lightning.loops.progress import Progress, ReadyCompletedTracker -from pytorch_lightning.loops.utilities import _build_training_step_kwargs from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -103,8 +102,6 @@ def advance(self, kwargs: OrderedDict) -> None: Args: kwargs: The kwargs passed down to the hooks. """ - kwargs = self._build_kwargs(kwargs) - # manually capture logged metrics training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values()) del kwargs # release the batch from memory @@ -134,14 +131,3 @@ def _on_before_step(self) -> None: def _on_after_step(self) -> None: self.trainer.profiler.stop("optimizer_step") self.optim_step_progress.increment_completed() - - def _build_kwargs(self, kwargs: OrderedDict) -> OrderedDict: - """Helper method to build the arguments for the current step. - - Args: - kwargs: The kwargs passed down to the hooks. - - Returns: - The kwargs passed down to the hooks. - """ - return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, None) diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py index 13a86a93a51c6..686ed362399ef 100644 --- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -13,17 +13,17 @@ # limitations under the License. from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union +from typing import Any, Callable, Dict, Optional, OrderedDict, Union import torch from torch import Tensor from torch.optim import Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.loops import _Loop +from pytorch_lightning.loops.loop import _Loop from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult from pytorch_lightning.loops.progress import OptimizationProgress -from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs +from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import WarningCache from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -141,91 +141,26 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: return self._result.loss -_OUTPUTS_TYPE = Dict[int, Dict[str, Any]] +_OUTPUTS_TYPE = Dict[str, Any] class _OptimizerLoop(_Loop): - """Iterates over one or multiple optimizers and for each one it calls the - :meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index - and the optimizer index if multiple optimizers are requested. - - It is the leaf node in the tree of loops and performs automatic optimization - (forward, zero grad, backward, optimizer step). 
- """ + """Performs automatic optimization (forward, zero grad, backward, optimizer step)""" output_result_cls = ClosureResult def __init__(self) -> None: super().__init__() self.optim_progress: OptimizationProgress = OptimizationProgress() - - self._outputs: _OUTPUTS_TYPE = {} self._skip_backward: bool = False - self._optimizers: Tuple[Optimizer, ...] = tuple() - self._indices: Tuple[int, ...] = tuple() - - @property - def optimizer_idx(self) -> int: - return self._indices[self.optim_progress.optimizer_position] - - @property - def done(self) -> bool: - """Returns ``True`` when the last optimizer in the sequence has run.""" - return self.optim_progress.optimizer_position >= len(self._indices) - - def run(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> _OUTPUTS_TYPE: - self.reset() - self.on_run_start(optimizers) - while not self.done: - try: - self.advance(kwargs) - self._restarting = False - except StopIteration: - break - self._restarting = False - return self.on_run_end() - - def reset(self) -> None: - if not self.restarting: - # when reset() is called from outside (manually), we reset the loop progress - self.optim_progress.optimizer_position = 0 - else: - self.optim_progress.reset_on_restart() - self._outputs = {} - - def on_run_start(self, optimizers: List[Tuple[int, Optimizer]]) -> None: - self._indices, self._optimizers = zip(*optimizers) - if self.done: - self.optim_progress.optimizer_position = 0 - - def advance(self, kwargs: OrderedDict) -> None: - kwargs = self._build_kwargs(kwargs, self.optimizer_idx) - - result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position]) - if result.loss is not None: - # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch - # would be skipped otherwise - self._outputs[self.optimizer_idx] = result.asdict() - self.optim_progress.optimizer_position += 1 - - def on_run_end(self) -> _OUTPUTS_TYPE: - outputs, self._outputs = self._outputs, {} # free memory - self._indices = tuple() - self._optimizers = tuple() - return outputs - - def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimizer) -> ClosureResult: + + def run(self, optimizer: Optimizer, kwargs: OrderedDict) -> _OUTPUTS_TYPE: """Runs closure (train step + backward) together with optimization if necessary. Args: - kwargs: the kwargs passed down to the hooks. 
- optimizer: the current optimizer + kwargs: the kwargs passed down to the hooks + optimizer: the optimizer """ - opt_idx = kwargs.get("optimizer_idx", 0) - - # toggle model params - self._run_optimization_start(opt_idx, optimizer) - closure = self._make_closure(kwargs, optimizer) if ( @@ -247,28 +182,26 @@ def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimize # ------------------------------ # gradient update with accumulated gradients else: - self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure) + self._optimizer_step(optimizer, kwargs.get("batch_idx", 0), closure) result = closure.consume_result() - - # untoggle model params - self._run_optimization_end(optimizer) - return result + if result.loss is None: + return {} + return result.asdict() def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer) -> Closure: """Build a closure object that captures the given arguments and runs the `training_step` function and optionally other functions such as `backward` and `zero_grad`.""" - opt_idx = kwargs.get("optimizer_idx", 0) step_fn = self._make_step_fn(kwargs) - backward_fn = self._make_backward_fn(optimizer, opt_idx) - zero_grad_fn = self._make_zero_grad_fn(kwargs.get("batch_idx", 0), opt_idx, optimizer) + backward_fn = self._make_backward_fn(optimizer) + zero_grad_fn = self._make_zero_grad_fn(kwargs.get("batch_idx", 0), optimizer) return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn) def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]: """Build the step function that runs the `training_step` and processes its output.""" return partial(self._training_step, kwargs) - def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: + def _make_zero_grad_fn(self, batch_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: """Build a `zero_grad` function that zeroes the gradients before back-propagation. Returns ``None`` in the case backward needs to be skipped. @@ -283,11 +216,11 @@ def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) def zero_grad_fn() -> None: self._on_before_zero_grad(optimizer) - self._optimizer_zero_grad(batch_idx, optimizer, opt_idx) + self._optimizer_zero_grad(batch_idx, optimizer) return zero_grad_fn - def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Callable[[Tensor], None]]: + def _make_backward_fn(self, optimizer: Optimizer) -> Optional[Callable[[Tensor], None]]: """Build a `backward` function that handles back-propagation through the output produced by the `training_step` function. @@ -297,32 +230,13 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call return None def backward_fn(loss: Tensor) -> None: - self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx) + self.trainer._call_strategy_hook("backward", loss, optimizer) return backward_fn - def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None: - """Toggles the optimizer to ensure the correct one is used and prevent dangling grads. - - Args: - opt_idx: the index of the optimizer to use - optimizer: the optimizer to use - """ - # make sure only the gradients of the current optimizer's parameters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. 
- if len(self.trainer.optimizers) > 1: - model = self.trainer.lightning_module - model.toggle_optimizer(optimizer) - - def _run_optimization_end(self, optimizer: Optimizer) -> None: - if len(self.trainer.optimizers) > 1: - model = self.trainer.lightning_module - model.untoggle_optimizer(optimizer) - def _optimizer_step( self, optimizer: Union[Optimizer, LightningOptimizer], - opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable[[], Optional[Tensor]], ) -> None: @@ -330,13 +244,12 @@ def _optimizer_step( Args: optimizer: the optimizer to perform the step with - opt_idx: the index of the current :param:`optimizer` batch_idx: the index of the current batch train_step_and_backward_closure: the closure function performing the train step and computing the gradients. By default, called by the optimizer (if possible) """ # wraps into LightningOptimizer only for running step - optimizer = self.trainer.strategy._lightning_optimizers[opt_idx] + optimizer = self.trainer.strategy._lightning_optimizers[0] # if `strategy.handles_gradient_accumulation`, this method will be called to route into the strategy, but we # need to check again if `should_accumulate` before increasing the counters @@ -350,7 +263,6 @@ def _optimizer_step( self.trainer.current_epoch, batch_idx, optimizer, - opt_idx, train_step_and_backward_closure, ) @@ -368,16 +280,15 @@ def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: self.trainer._call_lightning_module_hook("on_before_zero_grad", optimizer) self.optim_progress.optimizer.zero_grad.increment_started() - def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: + def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer) -> None: """Zeroes out all gradients of parameters optimized by the current optimizer. Args: batch_idx: the index of the current batch optimizer: the current optimizer - opt_idx: the index of the current optimizer """ self.trainer._call_lightning_module_hook( - "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx + "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer ) self.optim_progress.optimizer.zero_grad.increment_completed() @@ -401,17 +312,4 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult: result = self.output_result_cls.from_training_step_output( training_step_output, self.trainer.accumulate_grad_batches ) - return result - - def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int) -> OrderedDict: - """Helper method to build the arguments for the current step. - - Args: - kwargs: The kwargs passed down to the hooks. - opt_idx: the index of the current optimizer. - - Returns: - The kwargs passed down to the hooks. - """ - return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx) diff --git a/src/pytorch_lightning/loops/progress.py b/src/pytorch_lightning/loops/progress.py index 8292178a06f11..2a0e687e5bf3d 100644 --- a/src/pytorch_lightning/loops/progress.py +++ b/src/pytorch_lightning/loops/progress.py @@ -259,15 +259,9 @@ class OptimizationProgress(BaseProgress): Args: optimizer: Tracks optimizer progress. - optimizer_position: The index of the current optimizer amongst the currently active optimizers. - Used to know which optimizer we were using when restarting. - Since not all optimizers may be active at a given time, this index is different from the ``optimizer_idx`` - seen in the optimization loops. 
""" - # TODO: support for multiple optimizers optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) - optimizer_position: int = 0 @property def optimizer_steps(self) -> int: @@ -275,15 +269,12 @@ def optimizer_steps(self) -> int: def reset(self) -> None: self.optimizer.reset() - self.optimizer_position = 0 def reset_on_run(self) -> None: self.optimizer.reset_on_run() - self.optimizer_position = 0 def reset_on_restart(self) -> None: self.optimizer.reset_on_restart() def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) - self.optimizer_position = state_dict["optimizer_position"] diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py index d0187697aaf7e..c8d61fc35783b 100644 --- a/src/pytorch_lightning/loops/utilities.py +++ b/src/pytorch_lightning/loops/utilities.py @@ -11,15 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict from contextlib import contextmanager -from functools import lru_cache -from typing import Generator, List, Optional, Sequence, Tuple, Union +from typing import Generator, Optional, Tuple, Union -import numpy as np import torch from torch import Tensor -from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -31,7 +27,6 @@ from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature def check_finite_loss(loss: Optional[Tensor]) -> None: @@ -86,43 +81,6 @@ def _parse_loop_limits( return min_epochs, max_epochs -def _build_training_step_kwargs( - kwargs: OrderedDict, - lightning_module: "pl.LightningModule", - optimizers: Sequence[Optimizer], - opt_idx: Optional[int], -) -> OrderedDict: - """Builds the keyword arguments for training_step. - - Args: - kwargs: The kwargs passed down to the hooks. - lightning_module: the LightningModule with a `training_step` hook implementation - optimizers: the list of optimizers from the Trainer - opt_idx: the index of the current optimizer - - Returns: - the keyword arguments for the training step - """ - training_step_fx = getattr(lightning_module, "training_step") - if len(optimizers) > 1: - has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") - if has_opt_idx_in_train_step: - if not lightning_module.automatic_optimization: - raise ValueError( - "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but" - " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx" - " argument or set `self.automatic_optimization = True`." - ) - kwargs["optimizer_idx"] = opt_idx - elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization: - raise ValueError( - f"Your LightningModule defines {len(optimizers)} optimizers but" - " `training_step` is missing the `optimizer_idx` argument." - ) - - return kwargs - - @contextmanager def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Generator[None, None, None]: """Blocks synchronization in :class:`~pytorch_lightning.strategies.parallel.ParallelStrategy`. 
This is useful @@ -142,33 +100,6 @@ def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Gen yield None -@lru_cache(1) -def _cumulative_optimizer_frequencies(frequencies: Tuple[int]) -> np.ndarray: - return np.cumsum(frequencies) - - -def _get_active_optimizers( - optimizers: List[Optimizer], frequencies: List[int], batch_idx: int -) -> List[Tuple[int, Optimizer]]: - """Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only - one of the optimizers is active at a time. - - Returns: - A list of tuples (opt_idx, optimizer) of currently active optimizers. - """ - if not frequencies: - # call training_step once per optimizer - return list(enumerate(optimizers)) - - freq_cumsum = _cumulative_optimizer_frequencies(tuple(frequencies)) - optimizers_loop_length = freq_cumsum[-1] - current_place_in_loop = batch_idx % optimizers_loop_length - - # find optimizer index by looking for the first {item > current_place} in the cumsum list - opt_idx = np.searchsorted(freq_cumsum, current_place_in_loop, side="right") - return [(opt_idx, optimizers[opt_idx])] - - def _is_max_limit_reached(current: int, maximum: int = -1) -> bool: """Check if the limit has been reached (if enabled). diff --git a/src/pytorch_lightning/plugins/precision/amp.py b/src/pytorch_lightning/plugins/precision/amp.py index 224c32959e662..f5262a5c7c8f2 100644 --- a/src/pytorch_lightning/plugins/precision/amp.py +++ b/src/pytorch_lightning/plugins/precision/amp.py @@ -55,19 +55,14 @@ def optimizer_step( # type: ignore[override] self, optimizer: Optimizable, model: "pl.LightningModule", - optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: if self.scaler is None: # skip scaler logic, as bfloat16 does not require scaler - return super().optimizer_step( - optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs - ) + return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs) if isinstance(optimizer, LBFGS): - raise MisconfigurationException( - f"AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." - ) + raise MisconfigurationException("AMP and the LBFGS optimizer are not compatible.") closure_result = closure() if not _optimizer_handles_unscaling(optimizer): @@ -76,7 +71,7 @@ def optimizer_step( # type: ignore[override] # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook. 
self.scaler.unscale_(optimizer) - self._after_closure(model, optimizer, optimizer_idx) + self._after_closure(model, optimizer) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value if not model.automatic_optimization or not skipped_backward: diff --git a/src/pytorch_lightning/plugins/precision/colossalai.py b/src/pytorch_lightning/plugins/precision/colossalai.py index 5baa67044acf9..1bb8a95416e5a 100644 --- a/src/pytorch_lightning/plugins/precision/colossalai.py +++ b/src/pytorch_lightning/plugins/precision/colossalai.py @@ -48,7 +48,6 @@ def backward( # type: ignore[override] tensor: Tensor, model: "pl.LightningModule", optimizer: Optional[Steppable], - optimizer_idx: Optional[int], *args: Any, **kwargs: Any, ) -> None: @@ -65,12 +64,11 @@ def optimizer_step( # type: ignore[override] self, optimizer: Steppable, model: "pl.LightningModule", - optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: closure_result = closure() - self._after_closure(model, optimizer, optimizer_idx) + self._after_closure(model, optimizer) skipped_backward = closure_result is None if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward: raise ValueError( diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index 61a88fa6a9fcd..d2a824a3fbaa9 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -60,7 +60,6 @@ def backward( # type: ignore[override] tensor: Tensor, model: "pl.LightningModule", optimizer: Optional[Steppable], - optimizer_idx: Optional[int], *args: Any, **kwargs: Any, ) -> None: @@ -70,7 +69,6 @@ def backward( # type: ignore[override] tensor: the loss tensor model: the model to be optimized optimizer: ignored for DeepSpeed - optimizer_idx: ignored for DeepSpeed \*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call \**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call """ @@ -86,16 +84,13 @@ def optimizer_step( # type: ignore[override] self, optimizer: Steppable, model: "pl.LightningModule", - optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: if isinstance(optimizer, LBFGS): - raise MisconfigurationException( - f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." - ) + raise MisconfigurationException("DeepSpeed and the LBFGS optimizer are not compatible.") closure_result = closure() - self._after_closure(model, optimizer, optimizer_idx) + self._after_closure(model, optimizer) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value if model.automatic_optimization and skipped_backward: diff --git a/src/pytorch_lightning/plugins/precision/ipu.py b/src/pytorch_lightning/plugins/precision/ipu.py index e3494c42c6efa..632474ec8408f 100644 --- a/src/pytorch_lightning/plugins/precision/ipu.py +++ b/src/pytorch_lightning/plugins/precision/ipu.py @@ -66,17 +66,14 @@ def optimizer_step( # type: ignore[override] self, optimizer: Optimizable, model: "pl.LightningModule", - optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: """IPUs handle the optimizer step internally.""" if isinstance(optimizer, LBFGS): - raise MisconfigurationException( - f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." 
- ) + raise MisconfigurationException("IPUs and the LBFGS optimizer are not compatible.") closure_result = closure() - self._after_closure(model, optimizer, optimizer_idx) + self._after_closure(model, optimizer) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value if model.automatic_optimization and skipped_backward: diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index f4d4ab2bde8ed..2afabc6acfa17 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -49,7 +49,6 @@ def backward( # type: ignore[override] tensor: Tensor, model: "pl.LightningModule", optimizer: Optional[Steppable], - optimizer_idx: Optional[int], *args: Any, **kwargs: Any, ) -> None: @@ -59,12 +58,11 @@ def backward( # type: ignore[override] tensor: the loss value obtained from the closure model: the model to be optimized optimizer: current optimizer being used. ``None`` if using manual optimization - optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization \*args: Positional arguments intended for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`. \**kwargs: Keyword arguments for the same purpose as ``*args``. """ - model.backward(tensor, optimizer, optimizer_idx, *args, **kwargs) + model.backward(tensor, *args, **kwargs) def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] # once backward has been applied, release graph @@ -73,18 +71,15 @@ def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: module.trainer._call_lightning_module_hook("on_after_backward") return closure_loss - def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable, optimizer_idx: int) -> None: + def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable) -> None: """Utility to share some code after the closure has been run.""" trainer = model.trainer - trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx) - trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx) - # TODO: this is done for the entire model but should be changed to per-optimizer - if optimizer_idx == 0: - self._track_grad_norm(trainer) + trainer._call_callback_hooks("on_before_optimizer_step", optimizer) + trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer) + self._track_grad_norm(trainer) self._clip_gradients( model, optimizer, - optimizer_idx, trainer.gradient_clip_val, gradient_clip_algorithm=trainer.gradient_clip_algorithm, ) @@ -93,7 +88,6 @@ def _wrap_closure( self, model: "pl.LightningModule", optimizer: Optimizer, - optimizer_idx: int, closure: Callable[[], Any], ) -> Any: """This double-closure allows makes sure the ``closure`` is executed before the @@ -103,19 +97,18 @@ def _wrap_closure( consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly. 
""" closure_result = closure() - self._after_closure(model, optimizer, optimizer_idx) + self._after_closure(model, optimizer) return closure_result def optimizer_step( # type: ignore[override] self, optimizer: Steppable, model: "pl.LightningModule", - optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: """Hook to run the optimizer step.""" - closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + closure = partial(self._wrap_closure, model, optimizer, closure) return optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: @@ -137,7 +130,6 @@ def _clip_gradients( self, model: Union["pl.LightningModule", Module], optimizer: Steppable, - optimizer_idx: int, clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None, ) -> None: @@ -148,7 +140,6 @@ def _clip_gradients( model.trainer._call_lightning_module_hook( "configure_gradient_clipping", optimizer, - optimizer_idx, gradient_clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm, ) diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py index 93ef4cb87b0d6..d8b96274ded9c 100644 --- a/src/pytorch_lightning/plugins/precision/tpu.py +++ b/src/pytorch_lightning/plugins/precision/tpu.py @@ -40,14 +40,13 @@ def optimizer_step( # type: ignore[override] self, optimizer: Optimizable, model: "pl.LightningModule", - optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: import torch_xla.core.xla_model as xm closure = partial(self._tpu_wrap_closure, optimizer, closure) - closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) + closure = partial(self._wrap_closure, model, optimizer, closure) closure_result = optimizer.step(closure=closure, **kwargs) xm.mark_step() skipped_backward = closure_result is None diff --git a/src/pytorch_lightning/strategies/colossalai.py b/src/pytorch_lightning/strategies/colossalai.py index 3232d114efe6e..e7a9877859e56 100644 --- a/src/pytorch_lightning/strategies/colossalai.py +++ b/src/pytorch_lightning/strategies/colossalai.py @@ -399,7 +399,6 @@ def teardown(self) -> None: def optimizer_step( self, optimizer: Optimizer, - opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, @@ -407,9 +406,7 @@ def optimizer_step( model = model or self.lightning_module # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed assert isinstance(model, pl.LightningModule) - return self.precision_plugin.optimizer_step( - optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs - ) + return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) def lightning_module_state_dict(self, rank_zero_only: bool = False) -> Dict[str, Any]: """Returns a dictionary containing a whole state of the module. 
But all the tensors in the dictionary are diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 4f90a7a924082..70caa8c9f520e 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -240,7 +240,6 @@ def _enable_model_averaging(self) -> None: def optimizer_step( self, optimizer: Optimizer, - opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, @@ -249,12 +248,11 @@ def optimizer_step( Args: optimizer: the optimizer performing the step - opt_idx: index of the current optimizer closure: closure calculating the loss value model: reference to the model, optionally defining optimizer step related hooks **kwargs: Any extra arguments to ``optimizer.step`` """ - optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs) + optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs) if self._model_averager is None: return optimizer_output diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index e55d89aea9284..830d052271daa 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -456,18 +456,14 @@ def init_deepspeed(self) -> None: else: self._initialize_deepspeed_inference(model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig], Optional[int]]: + def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]: assert self.lightning_module is not None - optimizers, lr_schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) + optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(lr_schedulers) > 1: raise MisconfigurationException( "DeepSpeed currently only supports single optimizer, single optional scheduler." ) - return ( - optimizers[0], - lr_schedulers[0] if lr_schedulers else None, - optimizer_frequencies[0] if optimizer_frequencies else None, - ) + return optimizers[0], lr_schedulers[0] if lr_schedulers else None @property def zero_stage_3(self) -> bool: @@ -485,7 +481,10 @@ def _initialize_deepspeed_train(self, model: Module) -> None: ) lr_scheduler = None else: - optimizer, lr_scheduler, _ = self._init_optimizers() + ( + optimizer, + lr_scheduler, + ) = self._init_optimizers() if lr_scheduler is not None: scheduler = lr_scheduler.scheduler @@ -500,7 +499,7 @@ def _initialize_deepspeed_train(self, model: Module) -> None: # disable deepspeed lr scheduling as lightning manages scheduling model.lr_scheduler = None if lr_scheduler is None: - lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step", opt_idx=0) + lr_scheduler = LRSchedulerConfig(deepspeed_scheduler, interval="step") else: lr_scheduler.scheduler = deepspeed_scheduler self.lr_scheduler_configs = [lr_scheduler] @@ -587,10 +586,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: # Skip initializing optimizers here as DeepSpeed handles optimizers via config. 
# User may have specified config options instead in configure_optimizers, but this is handled # via `_initialize_deepspeed_train` - # empty optimizers, schedulers and frequencies + # empty optimizers, schedulers self.optimizers = [] self.lr_scheduler_configs = [] - self.optimizer_frequencies = [] @property def handles_gradient_accumulation(self) -> bool: diff --git a/src/pytorch_lightning/strategies/hpu_parallel.py b/src/pytorch_lightning/strategies/hpu_parallel.py index c10af58fab657..3fb50d6581f7c 100644 --- a/src/pytorch_lightning/strategies/hpu_parallel.py +++ b/src/pytorch_lightning/strategies/hpu_parallel.py @@ -145,12 +145,11 @@ def on_after_backward(self) -> None: def optimizer_step( self, optimizer: Optimizer, - opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, ) -> Any: - optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs) + optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs) # Break lazy accumulation of graph after optimizer htcore.mark_step() return optimizer_output diff --git a/src/pytorch_lightning/strategies/single_hpu.py b/src/pytorch_lightning/strategies/single_hpu.py index f05c7be15327c..a52b303b0882c 100644 --- a/src/pytorch_lightning/strategies/single_hpu.py +++ b/src/pytorch_lightning/strategies/single_hpu.py @@ -89,12 +89,11 @@ def on_after_backward(self) -> None: def optimizer_step( self, optimizer: Optimizer, - opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, ) -> Any: - optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs) + optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs) # Break lazy accumulation of graph after optimizer htcore.mark_step() return optimizer_output diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 3b490eaee9aae..7d0a39e03c813 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -67,7 +67,6 @@ def __init__( self._optimizers: List[Optimizer] = [] self._lightning_optimizers: Dict[int, LightningOptimizer] = {} self.lr_scheduler_configs: List[LRSchedulerConfig] = [] - self.optimizer_frequencies: List[int] = [] @property def launcher(self) -> Optional[_Launcher]: @@ -139,9 +138,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: if trainer.state.fn != TrainerFn.FITTING: return assert self.lightning_module is not None - self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers( - self.lightning_module - ) + self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module) def setup(self, trainer: "pl.Trainer") -> None: """Setup plugins for the trainer fit and creates optimizers. 
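The `optimizer_frequencies` bookkeeping removed from `Strategy` above used to back the `"frequency"` entry in `configure_optimizers` dictionaries; with it gone, a module that alternated optimizers at fixed frequencies has to do so itself under manual optimization. A minimal sketch of that pattern, assuming an illustrative (3, 1) frequency split and a toy model that are not taken from this patch:

import torch
from pytorch_lightning import LightningModule


class AlternatingOptimizersModel(LightningModule):
    """Illustrative replacement for the removed ``{"optimizer": ..., "frequency": n}`` dictionaries."""

    def __init__(self):
        super().__init__()
        # multiple optimizers now require manual optimization
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt_sgd, opt_adam = self.optimizers()
        # emulate frequencies (3, 1): three batches with SGD, then one with Adam, repeating
        opt = opt_sgd if batch_idx % 4 < 3 else opt_adam
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return [
            torch.optim.SGD(self.parameters(), lr=0.1),
            torch.optim.Adam(self.parameters(), lr=0.01),
        ]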
@@ -186,7 +183,6 @@ def backward( self, closure_loss: Tensor, optimizer: Optional[Optimizer], - optimizer_idx: Optional[int], *args: Any, **kwargs: Any, ) -> Tensor: @@ -195,7 +191,6 @@ def backward( Args: closure_loss: a tensor holding the loss value to backpropagate optimizer: An optional optimizer that gets passed down to the precision plugin's backward - optimizer_idx: An optional optimizer index that gets passed down to the precision plugin's backward \*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`. \**kwargs: Keyword arguments for the same purpose as ``*args``. @@ -204,7 +199,7 @@ def backward( assert self.lightning_module is not None closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module) - self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, optimizer_idx, *args, **kwargs) + self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs) closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module) self.post_backward(closure_loss) @@ -214,7 +209,6 @@ def backward( def optimizer_step( self, optimizer: Optimizer, - opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, @@ -223,17 +217,14 @@ def optimizer_step( Args: optimizer: the optimizer performing the step - opt_idx: index of the current optimizer closure: closure calculating the loss value model: reference to the model, optionally defining optimizer step related hooks - \**kwargs: Keyword arguments to to ``optimizer.step`` + \**kwargs: Keyword arguments to ``optimizer.step`` """ model = model or self.lightning_module # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed assert isinstance(model, pl.LightningModule) - return self.precision_plugin.optimizer_step( - optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs - ) + return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. 
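Since `Strategy.backward` and `Strategy.optimizer_step` no longer thread an `optimizer_idx` through to the module, the user-facing hooks lose that argument as well. A rough sketch of the override signatures after this patch, mirroring the signatures used in the updated tests below; the toy layer, hook bodies, and optimizer choice are illustrative only:

import torch
from pytorch_lightning import LightningModule


class UpdatedHooksModel(LightningModule):
    """Hook overrides after this patch: none of them receive an ``optimizer_idx`` anymore."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def backward(self, loss, *args, **kwargs):
        super().backward(loss, *args, **kwargs)

    def on_before_optimizer_step(self, optimizer):
        pass  # e.g. inspect gradients here

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad()

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        optimizer.step(closure=optimizer_closure)

    def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)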
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 233e645325a9e..29eeb8e4ed251 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1558,14 +1558,6 @@ def optimizers(self, new_optims: List[Optimizer]) -> None: def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: return self.strategy.lr_scheduler_configs - @property - def optimizer_frequencies(self) -> List[int]: - return self.strategy.optimizer_frequencies - - @optimizer_frequencies.setter - def optimizer_frequencies(self, new_freqs: List[int]) -> None: - self.strategy.optimizer_frequencies = new_freqs - @property def precision(self) -> _PRECISION_INPUT_STR: return self.strategy.precision_plugin.precision diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 1075bdacdb526..bdc11f9285bdd 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -106,7 +106,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: # TODO: update docs here """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified optimizer together with a new scheduler that takes care of the learning rate search.""" - from pytorch_lightning.core.optimizer import _set_scheduler_opt_idx + from pytorch_lightning.core.optimizer import _validate_optimizers_attached optimizers = trainer.strategy.optimizers @@ -128,8 +128,8 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: scheduler = cast(LRScheduler, scheduler) trainer.strategy.optimizers = [optimizer] - trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] - _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs) + trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step")] + _validate_optimizers_attached(trainer.optimizers, trainer.lr_scheduler_configs) def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]: """Plot results from lr_find run @@ -303,7 +303,6 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: return { "optimizers": trainer.strategy.optimizers, "lr_scheduler_configs": trainer.strategy.lr_scheduler_configs, - "optimizer_frequencies": trainer.strategy.optimizer_frequencies, "callbacks": trainer.callbacks, "loggers": trainer.loggers, "max_steps": trainer.fit_loop.max_steps, @@ -316,7 +315,6 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto from pytorch_lightning.loggers.logger import DummyLogger trainer.strategy.lr_scheduler_configs = [] - trainer.strategy.optimizer_frequencies = [] # Use special lr logger callback trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] # No logging @@ -329,7 +327,6 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: trainer.strategy.optimizers = params["optimizers"] trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"] - trainer.strategy.optimizer_frequencies = params["optimizer_frequencies"] trainer.callbacks = params["callbacks"] trainer.loggers = params["loggers"] trainer.fit_loop.max_steps = params["max_steps"] diff --git a/src/pytorch_lightning/utilities/migration/migration.py 
b/src/pytorch_lightning/utilities/migration/migration.py index 1a5bb038e6ce9..0a27af2ae0f03 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -46,7 +46,11 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], "1.6.5": [_migrate_loop_batches_that_stepped], "1.9.0": [_migrate_model_checkpoint_save_on_train_epoch_end_default], - "2.0.0": [_drop_apex_amp_state, _migrate_loop_structure_after_tbptt_removal], + "2.0.0": [ + _drop_apex_amp_state, + _migrate_loop_structure_after_tbptt_removal, + _migrate_loop_structure_after_optimizer_loop_removal, + ], } @@ -227,8 +231,8 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE became the children of the training epoch loop. Version: 2.0.0 - Commit: TBD - PR: #16172 + Commit: 7807454 + PR: #16337, #16172 """ if "loops" not in checkpoint: return checkpoint @@ -256,3 +260,21 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE fit_loop.pop("epoch_loop.batch_loop.state_dict", None) return checkpoint + + +def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Adjusts the loop structure since it changed when the support for multiple optimizers in automatic + optimization mode was removed. There is no longer a loop over optimizer, and hence no position to store for + resuming the loop. + + Version: 2.0.0 + Commit: TBD + PR: TBD + """ + if "loops" not in checkpoint: + return checkpoint + + # TODO: Complete this migration function when optimizer loop gets flattened out and keys need to be remapped + fit_loop = checkpoint["loops"]["fit_loop"] + fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None) + return checkpoint diff --git a/src/pytorch_lightning/utilities/signature_utils.py b/src/pytorch_lightning/utilities/signature_utils.py index 416b677f1765b..55a72498f07c9 100644 --- a/src/pytorch_lightning/utilities/signature_utils.py +++ b/src/pytorch_lightning/utilities/signature_utils.py @@ -23,7 +23,7 @@ def is_param_in_hook_signature( hook_fx: the hook callable param: the name of the parameter to check explicit: whether the parameter has to be explicitly declared - min_args: whether the `signature` as at least `min_args` parameters + min_args: whether the `signature` has at least `min_args` parameters """ parameters = inspect.getfullargspec(hook_fx) args = parameters.args[1:] # ignore `self` diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 672de4115c422..1e9ea753d78cf 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -130,5 +130,3 @@ class LRSchedulerConfig: monitor: Optional[str] = None # enforce that the monitor exists for ReduceLROnPlateau strict: bool = True - # opt_idx assigned internally if not assigned by user - opt_idx: Optional[int] = None diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 08853b5d7ae26..bebbb43874d40 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -546,6 +546,9 @@ def configure_optimizers(self): return [torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters())] model = TestModel() + # Must switch to manual optimization mode, otherwise we 
would get a different error + # (multiple optimizers only supported with manual optimization) + model.automatic_optimization = False trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1) with pytest.raises(MisconfigurationException, match="IPUs currently only support one optimizer."): diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 773d72049641e..42b721e509297 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -75,7 +75,7 @@ def train_dataloader(self): class TestBackboneFinetuningWarningCallback(BackboneFinetuning): - def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): + def finetune_function(self, pl_module, epoch: int, optimizer): """Called when the epoch begins.""" if epoch == 0: @@ -211,7 +211,7 @@ class OnEpochLayerFinetuning(BaseFinetuning): def freeze_before_training(self, pl_module: LightningModule): self.freeze(pl_module.layer) - def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer): self.unfreeze_and_add_param_group(pl_module.layer[epoch + 1], optimizer) @@ -316,7 +316,7 @@ class TestCallbacksRestoreCallback(BaseFinetuning): def freeze_before_training(self, pl_module): self.freeze(pl_module.layer[:3]) - def finetune_function(self, pl_module, epoch, optimizer, opt_idx): + def finetune_function(self, pl_module, epoch, optimizer): if epoch >= 1: self.unfreeze_and_add_param_group(pl_module.layer[epoch - 1], optimizer) diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py index 1a54dd15dec84..b9a06389668d1 100644 --- a/tests/tests_pytorch/callbacks/test_lr_monitor.py +++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py @@ -538,7 +538,7 @@ def freeze_before_training(self, pl_module): self.freeze(pl_module.backbone[1]) self.freeze(pl_module.layer) - def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): + def finetune_function(self, pl_module, epoch: int, optimizer): """Called when the epoch begins.""" if epoch == 1 and isinstance(optimizer, torch.optim.SGD): self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1) diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 2381b7944399c..d96578b597a73 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -358,7 +358,7 @@ class TestModel(BoringModel): has_validated_gradients = False custom_gradient_clip_val = 1e-2 - def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): + def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): assert gradient_clip_val == self.trainer.gradient_clip_val assert gradient_clip_algorithm == self.trainer.gradient_clip_algorithm @@ -387,7 +387,7 @@ def test_lightning_module_configure_gradient_clipping_different_argument_values( class TestModel(BoringModel): custom_gradient_clip_val = 1e-2 - def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): + def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): self.clip_gradients(optimizer, 
gradient_clip_val=self.custom_gradient_clip_val) model = TestModel() @@ -403,7 +403,7 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va class TestModel(BoringModel): custom_gradient_clip_algorithm = "foo" - def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): + def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): self.clip_gradients(optimizer, gradient_clip_algorithm=self.custom_gradient_clip_algorithm) model = TestModel() diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index 75483e62b82cb..4fc96870155d8 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -169,7 +169,7 @@ class TestModel(BoringModel): def training_epoch_end(self, outputs): ... - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): if batch_idx % 2 == 0: optimizer.zero_grad() @@ -198,7 +198,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): def training_epoch_end(self, outputs): ... - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, **_): assert isinstance(optimizer_closure, Closure) # zero_grad is called inside the closure optimizer_closure() @@ -315,7 +315,7 @@ def training_step(self, batch, batch_idx): def configure_optimizers(self): return SGD(self.layer.parameters(), lr=0.1) - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **__): + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, **__): # check attributes are accessible assert all("lr" in pg for pg in optimizer.param_groups) assert optimizer.state is optimizer._optimizer.state diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index fff8445f618dd..83e1734ff02a1 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -110,14 +110,14 @@ def configure_optimizers__lr_on_plateau_step(self): scheduler = {"scheduler": lr_scheduler, "interval": "step", "monitor": "pbar_acc1"} return [optimizer], [scheduler] - def backward(self, loss, optimizer, optimizer_idx): + def backward(self, loss, *args, **kwargs): if self.assert_backward: if self.trainer.precision == "16": assert loss > 171 * 1000 else: assert loss == 171.0 - super().backward(loss, optimizer, optimizer_idx) + return super().backward(loss, *args, **kwargs) class DummyDataset(Dataset): diff --git a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py index 6fefe498e7554..aa870f53fc17d 100644 --- a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py @@ -16,122 +16,9 @@ import pytest -from pytorch_lightning import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.loops import _TrainingEpochLoop from pytorch_lightning.trainer.trainer import Trainer -_out00 = {"loss": 0.0} -_out01 = {"loss": 0.1} -_out02 = {"loss": 0.2} -_out03 = {"loss": 0.3} -_out10 = {"loss": 1.0} -_out11 = {"loss": 1.1} -_out12 = {"loss": 1.2} -_out13 = {"loss": 1.3} - - 
-class TestPrepareOutputs: - def prepare_outputs(self, fn, batch_outputs, num_optimizers, automatic_optimization): - lightning_module = LightningModule() - lightning_module.automatic_optimization = automatic_optimization - return fn( - batch_outputs, - lightning_module=lightning_module, - num_optimizers=num_optimizers, # does not matter for manual optimization - ) - - def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True): - return self.prepare_outputs( - _TrainingEpochLoop._prepare_outputs_training_epoch_end, - batch_outputs, - num_optimizers, - automatic_optimization=automatic_optimization, - ) - - def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True): - return self.prepare_outputs( - _TrainingEpochLoop._prepare_outputs_training_batch_end, - batch_outputs, - num_optimizers, - automatic_optimization=automatic_optimization, - ) - - @pytest.mark.parametrize( - "num_optimizers,batch_outputs,expected", - [ - (1, [], []), - (1, [[]], []), - # 1 batch - (1, [[{0: _out00}]], [_out00]), - # 2 batches - (1, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]), - # 1 batch, 2 optimizers - (2, [[{0: _out00, 1: _out01}]], [_out00, _out01]), - # 2 batches, 2 optimizers - (2, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]), - # 4 batches, 2 optimizers, different frequency - ( - 2, - [[{0: _out00}], [{1: _out10}], [{1: _out11}], [{0: _out01}]], - [[_out00], [_out10], [_out11], [_out01]], - ), - ], - ) - def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, batch_outputs, expected): - """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook - currently expects in the case of automatic optimization.""" - assert self.prepare_outputs_training_epoch_end(batch_outputs, num_optimizers) == expected - - @pytest.mark.parametrize( - "batch_outputs,expected", - [ - ([], []), - ([[]], []), - # 1 batch - ([[_out00]], [_out00]), - # 2 batches - ([[_out00], [_out01]], [_out00, _out01]), - # skipped outputs - ([[_out00], [], [], [_out03]], [_out00, _out03]), - ], - ) - def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected): - """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook - currently expects in the case of manual optimization.""" - assert self.prepare_outputs_training_epoch_end(batch_outputs, -1, automatic_optimization=False) == expected - - @pytest.mark.parametrize( - "num_optimizers,batch_end_outputs,expected", - [ - (1, [], []), - (1, [[]], []), - # 1 optimizer - (1, [{0: _out00}], _out00), - # 2 optimizers - (2, [{0: _out00, 1: _out01}], [_out00, _out01]), - ], - ) - def test_prepare_outputs_training_batch_end_automatic(self, num_optimizers, batch_end_outputs, expected): - """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook - currently expects in the case of automatic optimization.""" - - assert self.prepare_outputs_training_batch_end(batch_end_outputs, num_optimizers) == expected - - @pytest.mark.parametrize( - "batch_end_outputs,expected", - [ - ([], []), - ([[]], []), - # skipped outputs - ([_out00, None, _out02], [_out00, _out02]), - ], - ) - def test_prepare_outputs_training_batch_end_manual(self, batch_end_outputs, expected): - """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook - currently expects in 
the case of manual optimization.""" - assert self.prepare_outputs_training_batch_end(batch_end_outputs, -1, automatic_optimization=False) == expected - def test_no_val_on_train_epoch_loop_restart(tmpdir): """Test that training validation loop doesn't get triggered at the beginning of a restart.""" diff --git a/tests/tests_pytorch/loops/optimization/test_closure.py b/tests/tests_pytorch/loops/optimization/test_closure.py index c5de071766f15..638975e8c6793 100644 --- a/tests/tests_pytorch/loops/optimization/test_closure.py +++ b/tests/tests_pytorch/loops/optimization/test_closure.py @@ -21,9 +21,7 @@ def test_optimizer_step_no_closure_raises(tmpdir): class TestModel(BoringModel): - def optimizer_step( - self, epoch=None, batch_idx=None, optimizer=None, optimizer_idx=None, optimizer_closure=None, **_ - ): + def optimizer_step(self, epoch=None, batch_idx=None, optimizer=None, optimizer_closure=None, **_): # does not call `optimizer_closure()` pass diff --git a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py b/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py index 27826d5fd88eb..b0090e5c96160 100644 --- a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py @@ -11,15 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock import pytest import torch -from torch.optim import Adam, SGD -from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import OnExceptionCheckpoint -from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -65,169 +61,5 @@ def training_step(self, batch, batch_idx): trainer.fit(model) -@pytest.mark.parametrize( - "frequencies,expected", - [ - ( - (3, 1), - [ - (0, "SGD"), - (0, "SGD"), - (0, "SGD"), - (1, "Adam"), - (0, "SGD"), - (0, "SGD"), - (0, "SGD"), - (1, "Adam"), - (0, "SGD"), - (0, "SGD"), - ], - ), - ( - (1, 2), - [ - (0, "SGD"), - (1, "Adam"), - (1, "Adam"), - (0, "SGD"), - (1, "Adam"), - (1, "Adam"), - (0, "SGD"), - (1, "Adam"), - (1, "Adam"), - (0, "SGD"), - ], - ), - ], -) -def test_optimizer_frequencies(tmpdir, frequencies, expected): - """Test that the optimizer loop runs optimization for the correct optimizer and optimizer idx when different - frequencies are requested.""" - - class CurrentModel(BoringModel): - def training_step(self, batch, batch_idx, optimizer_idx): - return super().training_step(batch, batch_idx) - - def configure_optimizers(self): - opt0 = SGD(self.parameters(), lr=0.1) - opt1 = Adam(self.parameters(), lr=0.1) - return {"optimizer": opt0, "frequency": frequencies[0]}, {"optimizer": opt1, "frequency": frequencies[1]} - - model = CurrentModel() - model.training_epoch_end = None - model.optimizer_step = Mock(wraps=model.optimizer_step) - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=10, - enable_progress_bar=False, - ) - trainer.fit(model) - - positional_args = [c[0] for c in model.optimizer_step.call_args_list] - pl_optimizer_sequence = [args[2] for args in positional_args] - opt_idx_sequence = [args[3] for args in positional_args] - assert all(isinstance(opt, 
LightningOptimizer) for opt in pl_optimizer_sequence) - optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence] - assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected - - class CustomException(Exception): pass - - -@pytest.mark.parametrize("stop_epoch", (0, 1)) -@pytest.mark.parametrize("stop_batch", (0, 1, 2)) -@pytest.mark.parametrize("n_optimizers,stop_optimizer", [(2, 0), (2, 1), (3, 2)]) -def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers, stop_optimizer, stop_epoch, stop_batch): - """Test that Lightning can resume from a point where a training_step failed while in the middle of processing - several optimizer steps for one batch. - - The test asserts that we end up with the same trained weights as if no failure occurred. - """ - - n_batches = 3 - n_epochs = 2 - - def _assert_optimizer_sequence(method_mock, expected): - positional_args = [c[0] for c in method_mock.call_args_list] - sequence = [arg[3] for arg in positional_args] - assert sequence == expected - - num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer - - opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches # [0, 1, 2, 0, 1, 2, 0, 1, ...] - # +1 because we fail inside the closure inside optimizer_step() - opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)] - opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:] - - class MultipleOptimizerModel(BoringModel): - def training_step(self, batch, batch_idx, optimizer_idx): - if ( - fail - and self.current_epoch == stop_epoch - and batch_idx == stop_batch - and optimizer_idx == stop_optimizer - ): - raise CustomException - return super().training_step(batch, batch_idx) - - def configure_optimizers(self): - return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(n_optimizers)] - - # run without a failure, collect weights - fail = False - seed_everything(0) - model = MultipleOptimizerModel() - model.training_epoch_end = None - model.optimizer_step = Mock(wraps=model.optimizer_step) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=n_epochs, - limit_train_batches=n_batches, - limit_val_batches=0, - num_sanity_val_steps=0, - logger=False, - enable_checkpointing=False, - ) - trainer.fit(model) - model.parameters() - _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete) - - # simulate a failure - fail = True - seed_everything(0) - model = MultipleOptimizerModel() - model.training_epoch_end = None - model.optimizer_step = Mock(wraps=model.optimizer_step) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=n_epochs, - limit_train_batches=n_batches, - limit_val_batches=0, - num_sanity_val_steps=0, - logger=False, - callbacks=OnExceptionCheckpoint(tmpdir), - ) - with pytest.raises(CustomException): - trainer.fit(model) - - _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_incomplete) - - # resume from failure and collect weights - fail = False - seed_everything(0) - model = MultipleOptimizerModel() - model.training_epoch_end = None - model.optimizer_step = Mock(wraps=model.optimizer_step) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=n_epochs, - limit_train_batches=n_batches, - limit_val_batches=0, - num_sanity_val_steps=0, - logger=False, - enable_checkpointing=False, - ) - trainer.fit(model, ckpt_path=str(tmpdir / "on_exception.ckpt")) - - _assert_optimizer_sequence(model.optimizer_step, 
opt_idx_sequence_resumed) diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index d244d6e08a78f..2a10b07b1edb5 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -40,8 +40,8 @@ def validation_step(self, batch, batch_idx): out = {"something": "random"} return out - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.validation_step_end = None @@ -65,10 +65,8 @@ def backward(self, loss, optimizer, optimizer_idx): # simulate training manually trainer.state.stage = RunningStage.TRAINING kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) - assert len(train_step_out) == 1 - train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 @@ -102,8 +100,8 @@ def validation_step_end(self, out): assert self.last_out == out return out - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.validation_epoch_end = None @@ -126,10 +124,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) - assert len(train_step_out) == 1 - train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 @@ -169,8 +165,8 @@ def validation_epoch_end(self, outputs): assert out_a == self.out_a assert out_b == self.out_b - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.validation_step_end = None @@ -228,8 +224,8 @@ def validation_epoch_end(self, outputs): assert out_a == self.out_a assert out_b == self.out_b - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index 364672fe6261e..975de5b6179e1 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -63,7 +63,6 @@ def test_loops_state_dict_structure(): "current": {"ready": 0, "started": 0, "completed": 0}, }, }, - "optimizer_position": 0, }, "epoch_loop.val_loop.state_dict": {}, "epoch_loop.val_loop.dataloader_progress": { diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 0944c42fe2be1..248e48d02e9d5 100644 --- 
a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -266,36 +266,18 @@ def val_dataloader(self): @pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3)) -@pytest.mark.parametrize("n_optimizers", (1, 3, 5)) @pytest.mark.parametrize("stop_epoch", (1, 2)) @pytest.mark.parametrize("stop_batch", (1, 2)) -@pytest.mark.parametrize("stop_optimizer", (1, 2)) -def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir): - stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0 +def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, tmpdir): n_epochs = 3 n_batches = 3 class TestModel(BoringModel): - def __init__(self): - super().__init__() - if n_optimizers > 1: - self.configure_optimizers = self.configure_optimizers_multiple - - def training_step(self, batch, batch_idx, optimizer_idx=0): - if self.trainer.current_epoch == stop_epoch and batch_idx == stop_batch and optimizer_idx == stop_optimizer: + def training_step(self, batch, batch_idx): + if self.trainer.current_epoch == stop_epoch and batch_idx == stop_batch: raise CustomException return super().training_step(batch, batch_idx) - def configure_optimizers_multiple(self): - optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)] - - lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1) - lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1) - # no scheduler for optimizer_2 - lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] - - return optimizers, lr_schedulers - model = TestModel() model.training_epoch_end = None @@ -324,35 +306,26 @@ def configure_optimizers_multiple(self): # `nbe_`: non-breaking epoch, as in, no exception will be raised. 
`be_`: breaking epoch nbe_batches_completed = stop_epoch * n_batches be_batches_completed = stop_batch - be_batches_ready = stop_batch + 1 # lightning applies leftover accumulated gradients when the epoch ends has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0 # number of batches that will call `optimizer.step()` during non-breaking and breaking epochs nbe_stepping_batches = nbe_batches_completed // accumulate_grad_batches be_stepping_batches = be_batches_completed // accumulate_grad_batches - nbe_total_opt_steps = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - does_last_be_batch_step = be_batches_ready % accumulate_grad_batches == 0 or has_leftover_accumulation_batches - be_total_opt_steps = be_stepping_batches * n_optimizers + does_last_be_batch_step * stop_optimizer + nbe_total_opt_steps = nbe_stepping_batches + has_leftover_accumulation_batches + be_total_opt_steps = be_stepping_batches assert optim_progress.optimizer_steps == nbe_total_opt_steps + be_total_opt_steps assert optim_progress.optimizer.step.current.completed == be_total_opt_steps has_opt_stepped_in_be = stop_batch + 1 >= accumulate_grad_batches - nbe_total_zero_grad = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers - does_last_be_batch_zero_grad = be_batches_completed % accumulate_grad_batches == 0 + nbe_total_zero_grad = nbe_stepping_batches + has_leftover_accumulation_batches # `max` because the first batch always zero-grads - be_total_zero_grad = max(1, be_stepping_batches) * n_optimizers + stop_optimizer * does_last_be_batch_zero_grad + be_total_zero_grad = max(1, be_stepping_batches) assert optim_progress.optimizer.zero_grad.total.completed == nbe_total_zero_grad + be_total_zero_grad assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad nbe_sch_steps = stop_epoch be_sch_steps = 0 # the current epoch did not complete - if n_optimizers > 1: - # assumes that the scheduler config is unchanged - # `* 1` because there is only one step-level scheduler - nbe_sch_steps = stop_epoch + nbe_stepping_batches + has_leftover_accumulation_batches * 1 - # `0 +` for the epoch-level scheduler - be_sch_steps = 0 + be_stepping_batches assert sch_progress.total.completed == nbe_sch_steps + be_sch_steps assert sch_progress.current.completed == be_sch_steps @@ -399,7 +372,6 @@ def configure_optimizers_multiple(self): }, "epoch_loop.optimizer_loop.state_dict": {}, "epoch_loop.optimizer_loop.optim_progress": { - "optimizer_position": stop_optimizer, "optimizer": { "step": { "total": { @@ -442,7 +414,6 @@ def configure_optimizers_multiple(self): # test resetting manually, we expect all `ready` counters to be reset to `completed` trainer.fit_loop.reset() trainer.fit_loop.epoch_loop.reset() - trainer.fit_loop.epoch_loop.optimizer_loop.reset() epoch_progress = trainer.fit_loop.epoch_progress assert epoch_progress.current.ready == stop_epoch @@ -464,31 +435,12 @@ def configure_optimizers_multiple(self): assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch -@pytest.mark.parametrize("n_optimizers", (1, 3, 5)) -def test_loop_state_on_complete_run(n_optimizers, tmpdir): +def test_loop_state_on_complete_run(tmpdir): n_epochs = 3 n_batches = 3 accumulate_grad_batches = 1 class TestModel(BoringModel): - def __init__(self): - super().__init__() - if n_optimizers > 1: - self.configure_optimizers = self.configure_optimizers_multiple - - def training_step(self, batch, batch_idx, optimizer_idx=0): - return super().training_step(batch, 
batch_idx) - - def configure_optimizers_multiple(self): - optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)] - - lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1) - lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1) - # no scheduler for optimizer_2 - lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] - - return optimizers, lr_schedulers - def train_dataloader(self): # override to test the `is_last_batch` value return DataLoader(RandomDataset(32, n_batches)) @@ -514,9 +466,6 @@ def train_dataloader(self): n_sch_steps_total = n_epochs n_sch_steps_current = 1 - if n_optimizers > 1: - n_sch_steps_total = n_epochs + n_epochs * n_batches - n_sch_steps_current = n_batches + 1 expected = { "state_dict": ANY, @@ -561,28 +510,27 @@ def train_dataloader(self): }, "epoch_loop.optimizer_loop.state_dict": {}, "epoch_loop.optimizer_loop.optim_progress": { - "optimizer_position": n_optimizers, "optimizer": { "step": { "total": { - "ready": n_epochs * n_batches * n_optimizers, - "completed": n_epochs * n_batches * n_optimizers, + "ready": n_epochs * n_batches, + "completed": n_epochs * n_batches, }, "current": { - "ready": n_batches * n_optimizers, - "completed": n_batches * n_optimizers, + "ready": n_batches, + "completed": n_batches, }, }, "zero_grad": { "total": { - "ready": n_epochs * n_batches * n_optimizers, - "started": n_epochs * n_batches * n_optimizers, - "completed": n_epochs * n_batches * n_optimizers, + "ready": n_epochs * n_batches, + "started": n_epochs * n_batches, + "completed": n_epochs * n_batches, }, "current": { - "ready": n_batches * n_optimizers, - "started": n_batches * n_optimizers, - "completed": n_batches * n_optimizers, + "ready": n_batches, + "started": n_batches, + "completed": n_batches, }, }, }, @@ -630,7 +578,6 @@ def test_fit_loop_reset(tmpdir): # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 fit_loop.reset() epoch_loop.reset() - optimizer_loop.reset() assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 @@ -647,7 +594,6 @@ def test_fit_loop_reset(tmpdir): assert epoch_loop.batch_progress.current.completed == 1 assert optimizer_loop.restarting - assert optimizer_loop.optim_progress.optimizer_position == 1 # reset state loaded from a checkpoint from the end of an epoch end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=4.ckpt")) @@ -662,7 +608,6 @@ def test_fit_loop_reset(tmpdir): # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0 fit_loop.reset() epoch_loop.reset() - optimizer_loop.reset() assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 @@ -678,8 +623,6 @@ def test_fit_loop_reset(tmpdir): assert epoch_loop.batch_progress.current.processed == 3 assert epoch_loop.batch_progress.current.completed == 3 - assert optimizer_loop.optim_progress.optimizer_position == 1 - @pytest.mark.parametrize( ["train_datasets", "val_datasets"], diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py index e82519ad6021f..89c9d7901cef3 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py @@ -30,8 +30,8 @@ def training_step(self, batch, batch_idx): self.training_step_called = True return {"loss": acc, "random_things": [1, "a", torch.tensor(2)]} - def backward(self, loss, optimizer, 
optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -69,8 +69,8 @@ def training_step_end(self, tr_step_output): self.training_step_end_called = True return tr_step_output - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -116,8 +116,8 @@ def training_epoch_end(self, outputs): assert self.count_num_graphs(b) == 0 assert {"random_things", "loss", "batch_idx"} == set(b.keys()) - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -169,8 +169,8 @@ def training_epoch_end(self, outputs): assert self.count_num_graphs(b) == 0 assert {"random_things", "loss", "batch_idx"} == set(b.keys()) - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index d9dd5fc341d47..1d4dd2726ca82 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -35,8 +35,8 @@ def training_step(self, batch, batch_idx): self.training_step_called = True return acc - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -74,8 +74,8 @@ def training_step_end(self, tr_step_output): self.training_step_end_called = True return tr_step_output - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -119,8 +119,8 @@ def training_epoch_end(self, outputs): assert "loss" in b assert isinstance(b, dict) - def backward(self, loss, optimizer, optimizer_idx): - return LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) - assert len(train_step_out) == 1 - train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 @@ -189,8 +187,8 @@ def training_epoch_end(self, outputs): assert "loss" in b assert isinstance(b, dict) - def backward(self, loss, optimizer, optimizer_idx): - return 
LightningModule.backward(self, loss, optimizer, optimizer_idx) + def backward(self, loss): + return LightningModule.backward(self, loss) model = TestModel() model.val_dataloader = None @@ -217,10 +215,8 @@ def backward(self, loss, optimizer, optimizer_idx): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) - assert len(train_step_out) == 1 - train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 @@ -302,7 +298,7 @@ def training_step(self, batch, batch_idx): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): kwargs = {"batch": batch, "batch_idx": batch_idx} - out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) + out = trainer.fit_loop.epoch_loop.optimizer_loop.run(trainer.optimizers[0], kwargs) if not batch_idx % 2: assert out == {} @@ -328,7 +324,7 @@ def train_dataloader(self): def on_train_batch_end(self, outputs, batch, batch_idx): if batch_idx % 2 == 0: - assert outputs == [] + assert outputs is None else: assert outputs diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 74f7fb15c6116..57ff501c5869f 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -318,16 +318,16 @@ def _auto_train_batch( dict(name="training_step_end", args=(dict(loss=ANY),)), dict(name="Callback.on_before_zero_grad", args=(trainer, model, ANY)), dict(name="on_before_zero_grad", args=(ANY,)), - dict(name="optimizer_zero_grad", args=(current_epoch, i, ANY, 0)), + dict(name="optimizer_zero_grad", args=(current_epoch, i, ANY)), dict(name="Callback.on_before_backward", args=(trainer, model, ANY)), dict(name="on_before_backward", args=(ANY,)), # DeepSpeed handles backward internally - *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []), + *([dict(name="backward", args=(ANY,))] if not using_deepspeed else []), dict(name="Callback.on_after_backward", args=(trainer, model)), dict(name="on_after_backward"), # note: unscaling happens here in the case of AMP - dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), - dict(name="on_before_optimizer_step", args=(ANY, 0)), + dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY)), + dict(name="on_before_optimizer_step", args=(ANY,)), *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []), dict( name="clip_gradients", @@ -336,17 +336,17 @@ def _auto_train_batch( ), dict( name="configure_gradient_clipping", - args=(ANY, 0), + args=(ANY,), kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None), ), # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates # the actual call to `PrecisionPlugin.optimizer_step` dict( name="optimizer_step", - args=(current_epoch, i, ANY, 0, ANY), + args=(current_epoch, i, ANY, ANY), ), *( - [dict(name="lr_scheduler_step", args=(ANY, 0, None))] + [dict(name="lr_scheduler_step", args=(ANY, None))] if i == (trainer.num_training_batches - 1) else [] ), @@ -372,14 +372,14 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k 
dict(name="Callback.on_before_backward", args=(trainer, model, ANY)), dict(name="on_before_backward", args=(ANY,)), # DeepSpeed handles backward internally - *([dict(name="backward", args=(ANY, None, None))] if not using_deepspeed else []), + *([dict(name="backward", args=(ANY,))] if not using_deepspeed else []), dict(name="Callback.on_after_backward", args=(trainer, model)), dict(name="on_after_backward"), # `manual_backward` calls the previous 3 dict(name="manual_backward", args=(ANY,)), dict(name="closure"), - dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)), - dict(name="on_before_optimizer_step", args=(ANY, 0)), + dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY)), + dict(name="on_before_optimizer_step", args=(ANY,)), *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []), dict(name="training_step", args=(ANY, i)), dict(name="training_step_end", args=(dict(loss=ANY),)), diff --git a/tests/tests_pytorch/plugins/precision/test_tpu.py b/tests/tests_pytorch/plugins/precision/test_tpu.py index a44ab5bc08b12..b3baf47a45b33 100644 --- a/tests/tests_pytorch/plugins/precision/test_tpu.py +++ b/tests/tests_pytorch/plugins/precision/test_tpu.py @@ -23,6 +23,6 @@ def test_optimizer_step_calls_mark_step(): plugin = TPUPrecisionPlugin() optimizer = Mock() with mock.patch("torch_xla.core.xla_model") as xm_mock: - plugin.optimizer_step(optimizer=optimizer, model=Mock(), optimizer_idx=0, closure=Mock()) + plugin.optimizer_step(optimizer=optimizer, model=Mock(), closure=Mock()) optimizer.step.assert_called_once() xm_mock.mark_step.assert_called_once() diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index aa13a63aea449..4720aee8af9dc 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -120,7 +120,7 @@ def configure_gradient_clipping(self, *args, **kwargs): # check clipping worked as expected self.check_grads_clipped() - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure, **_): + def optimizer_step(self, epoch, batch_idx, optimizer, closure, **_): # pass self as a kwarg optimizer.step(closure, pl_module=self) diff --git a/tests/tests_pytorch/strategies/test_colossalai.py b/tests/tests_pytorch/strategies/test_colossalai.py index 962ef49f32634..4a74d6bd6e23a 100644 --- a/tests/tests_pytorch/strategies/test_colossalai.py +++ b/tests/tests_pytorch/strategies/test_colossalai.py @@ -17,7 +17,6 @@ import torch import torch.nn.functional as F from torch import nn, Tensor -from torch.optim import Optimizer from torchmetrics import Accuracy from pytorch_lightning import LightningModule, seed_everything, Trainer @@ -143,7 +142,7 @@ def test_colossalai_optimizer(tmpdir): @RunIf(min_cuda_gpus=1, standalone=True, colossalai=True) def test_warn_colossalai_ignored(tmpdir): class TestModel(ModelParallelBoringModel): - def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + def backward(self, loss: Tensor, *args, **kwargs) -> None: return loss.backward() model = TestModel() diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index e864ae5c1031a..b4d977478ab90 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -22,7 +22,6 @@ import torch import torch.nn.functional as F from torch import nn, 
Tensor -from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Accuracy @@ -180,7 +179,7 @@ def test_deepspeed_defaults(): @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) def test_warn_deepspeed_ignored(tmpdir): class TestModel(BoringModel): - def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + def backward(self, loss: Tensor, *args, **kwargs) -> None: return loss.backward() model = TestModel() @@ -295,7 +294,6 @@ def on_train_start(self, trainer, pl_module) -> None: assert isinstance(trainer.optimizers[0].optimizer, torch.optim.SGD) assert isinstance(trainer.lr_scheduler_configs[0].scheduler, WarmupLR) assert trainer.lr_scheduler_configs[0].interval == "step" - assert trainer.lr_scheduler_configs[0].opt_idx == 0 model = BoringModel() lr_monitor = LearningRateMonitor() @@ -1098,9 +1096,8 @@ def test_deepspeed_configure_gradient_clipping(tmpdir): case of deepspeed.""" class TestModel(BoringModel): - def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm): - if optimizer_idx == 0: - self.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm) + def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): + self.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm) model = TestModel() trainer = Trainer( diff --git a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py index bcbc72d99e91b..78e227ccbfadf 100644 --- a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -104,62 +104,3 @@ def val_dataloader(self): ) trainer.fit(model) - - -def test_multiple_optimizers_multiple_dataloaders(tmpdir): - """Tests that only training_step can be used.""" - - class TestModel(BoringModel): - def on_train_epoch_start(self) -> None: - self.opt_0_seen = False - self.opt_1_seen = False - - def training_step(self, batch, batch_idx, optimizer_idx): - if optimizer_idx == 0: - self.opt_0_seen = True - elif optimizer_idx == 1: - self.opt_1_seen = True - else: - raise Exception("should only have two optimizers") - - self.training_step_called = True - loss = self.step(batch[0]) - return loss - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - - def validation_step(self, batch, batch_idx, dataloader_idx): - if dataloader_idx == 0: - assert batch.sum() == 0 - elif dataloader_idx == 1: - assert batch.sum() == 11 - else: - raise Exception("should only have two dataloaders") - - def val_dataloader(self): - dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11) - dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11) - return dl1, dl2 - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return optimizer, optimizer_2 - - model = TestModel() - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - log_every_n_steps=1, - enable_model_summary=False, - ) - - trainer.fit(model) - assert model.opt_0_seen - assert model.opt_1_seen diff --git 
a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index d77dd053b5646..7f2f0b61bebb4 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -929,8 +929,8 @@ def __init__(self): self.automatic_optimization = False def training_step(self, batch, batch_idx): - # Discriminator. optimizer1, optimizer2 = self.optimizers() + # Discriminator. self.toggle_optimizer(optimizer1) loss_d = self.step(batch) @@ -977,17 +977,3 @@ def configure_optimizers(self): assert set(trainer.logged_metrics) == {"loss_d", "loss_g"} assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"} - - -def test_manual_optimization_training_step_signature(tmpdir): - """Test that Lightning raises an exception if the training_step signature has an optimier_idx by mistake.""" - - class ConfusedAutomaticManualModel(ManualOptModel): - def training_step(self, batch, batch_idx, optimizer_idx): - return super().training_step(batch, batch_idx) - - model = ConfusedAutomaticManualModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) - - with pytest.raises(ValueError, match="Your `LightningModule.training_step` signature contains an `optimizer_idx`"): - trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py index 9c306fe8d2d74..60215766db1d0 100644 --- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py @@ -26,62 +26,19 @@ def configure_optimizers(self): return opt_a, opt_b -def test_unbalanced_logging_with_multiple_optimizers(tmpdir): - """This tests ensures reduction works in unbalanced logging settings.""" +def test_multiple_optimizers_automatic_optimization_raises(): + """Test that multiple optimizers in automatic optimization is not allowed.""" - class TestModel(MultiOptModel): - - actual = {0: [], 1: []} - - def training_step(self, batch, batch_idx, optimizer_idx): - out = super().training_step(batch, batch_idx) - loss = out["loss"] - self.log(f"loss_{optimizer_idx}", loss, on_epoch=True) - self.actual[optimizer_idx].append(loss) - return out - - model = TestModel() - model.training_epoch_end = None - - # Initialize a trainer - trainer = pl.Trainer( - default_root_dir=tmpdir, max_epochs=1, limit_train_batches=5, limit_val_batches=5, enable_model_summary=False - ) - trainer.fit(model) - - for k, v in model.actual.items(): - assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1]) - # test loss is properly reduced - torch.testing.assert_close(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean()) - - -def test_multiple_optimizers(tmpdir): - class TestModel(MultiOptModel): - - seen = [False, False] - - def training_step(self, batch, batch_idx, optimizer_idx): - self.seen[optimizer_idx] = True - return super().training_step(batch, batch_idx) - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 + class TestModel(BoringModel): + def configure_optimizers(self): + return torch.optim.Adam(self.parameters()), torch.optim.Adam(self.parameters()) model = TestModel() - model.val_dataloader = None - - trainer = pl.Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - 
log_every_n_steps=1, - enable_model_summary=False, - ) - trainer.fit(model) + model.automatic_optimization = True - assert all(model.seen) + trainer = pl.Trainer() + with pytest.raises(RuntimeError, match="multiple optimizers is only supported with manual optimization"): + trainer.fit(model) def test_multiple_optimizers_manual(tmpdir): @@ -122,77 +79,3 @@ def training_epoch_end(self, outputs) -> None: trainer.fit(model) assert model.training_step_called - - -def test_multiple_optimizers_no_opt_idx_argument(tmpdir): - """Test that an error is raised if no optimizer_idx is present when multiple optimizeres are passed in case of - automatic_optimization.""" - - class TestModel(MultiOptModel): - def training_step(self, batch, batch_idx): - return super().training_step(batch, batch_idx) - - trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=2) - - with pytest.raises(ValueError, match="`training_step` is missing the `optimizer_idx`"): - trainer.fit(TestModel()) - - -def test_custom_optimizer_step_with_multiple_optimizers(tmpdir): - """This tests ensures custom optimizer_step works, even when optimizer.step is not called for a particular - optimizer.""" - - class TestModel(BoringModel): - training_step_called = [0, 0] - optimizer_step_called = [0, 0] - - def __init__(self): - super().__init__() - self.layer_a = torch.nn.Linear(32, 2) - self.layer_b = torch.nn.Linear(32, 2) - - def configure_optimizers(self): - opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.001) - opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.001) - return opt_a, opt_b - - def training_step(self, batch, batch_idx, optimizer_idx): - self.training_step_called[optimizer_idx] += 1 - x = self.layer_a(batch[0]) if (optimizer_idx == 0) else self.layer_b(batch[0]) - loss = torch.nn.functional.mse_loss(x, torch.ones_like(x)) - return loss - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array of batches with an entry per optimizer - assert len(outputs) == limit_train_batches - assert all(len(o) == 2 for o in outputs) - - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_): - # update first optimizer every step - if optimizer_idx == 0: - self.optimizer_step_called[optimizer_idx] += 1 - optimizer.step(closure=optimizer_closure) - - # update second optimizer every 2 steps - if optimizer_idx == 1: - if batch_idx % 2 == 0: - self.optimizer_step_called[optimizer_idx] += 1 - optimizer.step(closure=optimizer_closure) - else: - optimizer_closure() - - model = TestModel() - model.val_dataloader = None - - limit_train_batches = 4 - trainer = pl.Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - max_epochs=1, - log_every_n_steps=1, - enable_model_summary=False, - ) - trainer.fit(model) - assert len(model.training_step_called) == len(model.optimizer_step_called) == len(model.optimizers()) - assert model.training_step_called == [limit_train_batches, limit_train_batches] - assert model.optimizer_step_called == [limit_train_batches, limit_train_batches // 2] diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index c16d331800c61..0a833b15f7262 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -18,7 +18,7 @@ import torch from torch import optim -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Trainer from 
pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.optimizer import ( _configure_optimizers, @@ -163,7 +163,6 @@ def configure_optimizers(self): frequency=1, reduce_on_plateau=True, strict=True, - opt_idx=0, name=None, ) @@ -178,24 +177,26 @@ def test_optimizer_return_options(tmpdir): opt_a = optim.Adam(model.parameters(), lr=0.002) opt_b = optim.SGD(model.parameters(), lr=0.002) scheduler_a = optim.lr_scheduler.StepLR(opt_a, 10) - scheduler_b = optim.lr_scheduler.StepLR(opt_b, 10) + optim.lr_scheduler.StepLR(opt_b, 10) # single optimizer model.configure_optimizers = lambda: opt_a - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) - assert len(opt) == 1 and len(lr_sched) == len(freq) == 0 + opt, lr_sched = _init_optimizers_and_lr_schedulers(model) + assert len(opt) == 1 and len(lr_sched) == 0 # opt tuple + model.automatic_optimization = False model.configure_optimizers = lambda: (opt_a, opt_b) - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) + opt, lr_sched = _init_optimizers_and_lr_schedulers(model) assert opt == [opt_a, opt_b] - assert len(lr_sched) == len(freq) == 0 + assert len(lr_sched) == 0 # opt list + model.automatic_optimization = False model.configure_optimizers = lambda: [opt_a, opt_b] - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) + opt, lr_sched = _init_optimizers_and_lr_schedulers(model) assert opt == [opt_a, opt_b] - assert len(lr_sched) == len(freq) == 0 + assert len(lr_sched) == 0 ref_lr_sched = LRSchedulerConfig( scheduler=scheduler_a, @@ -205,48 +206,32 @@ def test_optimizer_return_options(tmpdir): monitor=None, strict=True, name=None, - opt_idx=0, ) # opt tuple of 2 lists + model.automatic_optimization = True model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) + opt, lr_sched = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 1 - assert len(freq) == 0 assert opt[0] == opt_a assert lr_sched[0] == ref_lr_sched # opt tuple of 1 list + model.automatic_optimization = True model.configure_optimizers = lambda: ([opt_a], scheduler_a) - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) + opt, lr_sched = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 1 - assert len(freq) == 0 assert opt[0] == opt_a assert lr_sched[0] == ref_lr_sched # opt single dictionary + model.automatic_optimization = True model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) + opt, lr_sched = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 1 - assert len(freq) == 0 assert opt[0] == opt_a assert lr_sched[0] == ref_lr_sched - # opt multiple dictionaries with frequencies - model.configure_optimizers = lambda: ( - {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1}, - {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, - ) - opt, lr_sched, freq = _init_optimizers_and_lr_schedulers(model) - assert len(opt) == len(lr_sched) == len(freq) == 2 - assert opt[0] == opt_a - ref_lr_sched.opt_idx = 0 - assert lr_sched[0] == ref_lr_sched - ref_lr_sched.scheduler = scheduler_b - ref_lr_sched.opt_idx = 1 - assert lr_sched[1] == ref_lr_sched - assert freq == [1, 5] - def test_none_optimizer(tmpdir): model = BoringModel() @@ -267,74 +252,6 @@ def configure_optimizers(self): model = TestModel() trainer = Trainer(default_root_dir=tmpdir, 
fast_dev_run=True) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" - - -@pytest.mark.parametrize( - "schedulers, kwargs, intervals, frequencies, expected_steps, max_epochs", - [ - ( - (optim.lr_scheduler.OneCycleLR, optim.lr_scheduler.OneCycleLR), - (dict(max_lr=0.01, total_steps=3), dict(max_lr=0.01, total_steps=2)), - ("step", "step"), - (3, 2), - (4, 3), - 1, - ), - ( - (optim.lr_scheduler.OneCycleLR, optim.lr_scheduler.OneCycleLR), - (dict(max_lr=0.01, total_steps=5), dict(max_lr=0.01, total_steps=5)), - ("step", "step"), - (None, None), - (6, 6), - 1, - ), - ( - (optim.lr_scheduler.StepLR, optim.lr_scheduler.CosineAnnealingLR), - (dict(step_size=5), dict(T_max=2)), - ("epoch", "epoch"), - (5, 10), - (2, 3), - 3, - ), - ], -) -def test_step_scheduling_for_multiple_optimizers_with_frequency( - tmpdir, schedulers, kwargs, intervals, frequencies, expected_steps, max_epochs -): - """Test that step LR schedulers for multiple optimizers follow the optimizer frequencies when corresponding - frequency is set.""" - - class DummyModel(BoringModel): - def training_step(self, batch, batch_idx, optimizer_idx): - return super().training_step(batch, batch_idx) - - def training_epoch_end(self, outputs) -> None: - pass - - def configure_optimizers(self): - optimizer1 = optim.Adam(self.parameters(), lr=0.01) - optimizer2 = optim.Adam(self.parameters(), lr=0.01) - - lr_scheduler_config_1 = {"scheduler": schedulers[0](optimizer1, **kwargs[0]), "interval": intervals[0]} - lr_scheduler_config_2 = {"scheduler": schedulers[1](optimizer2, **kwargs[1]), "interval": intervals[1]} - - return [ - {"optimizer": optimizer1, "frequency": frequencies[0], "lr_scheduler": lr_scheduler_config_1}, - {"optimizer": optimizer2, "frequency": frequencies[1], "lr_scheduler": lr_scheduler_config_2}, - ] - - model = DummyModel() - - trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=1, limit_train_batches=5, max_epochs=max_epochs) - trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" - - assert trainer.lr_scheduler_configs[0].opt_idx == 0 - assert trainer.lr_scheduler_configs[1].opt_idx == 1 - # Step count is 1 greater than the expected value because scheduler.step() is called once during initialization - assert trainer.lr_scheduler_configs[0].scheduler._step_count == expected_steps[0] - assert trainer.lr_scheduler_configs[1].scheduler._step_count == expected_steps[1] @pytest.mark.parametrize("fn", ("validate", "test", "predict")) @@ -355,52 +272,6 @@ def configure_optimizers(self): assert len(trainer.lr_scheduler_configs) == 0 assert len(trainer.optimizers) == 0 - assert len(trainer.optimizer_frequencies) == 0 - - -def test_multiple_optimizers_callbacks(tmpdir): - """Tests that multiple optimizers can be used with callbacks.""" - - class CB(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - pass - - def on_train_epoch_start(self, trainer, pl_module): - pass - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.layer_1 = torch.nn.Linear(32, 2) - self.layer_2 = torch.nn.Linear(32, 2) - - def training_step(self, batch, batch_idx, optimizer_idx): - if optimizer_idx == 0: - a = batch[0] - acc = self.layer_1(a) - else: - a = batch[0] - acc = self.layer_2(a) - - acc = self.loss(acc, acc) - return acc - - def configure_optimizers(self): - a = optim.RMSprop(self.layer_1.parameters(), 1e-2) - b = optim.RMSprop(self.layer_2.parameters(), 1e-2) - return a, b - - model = 
TestModel() - model.training_epoch_end = None - trainer = Trainer( - callbacks=[CB()], - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=2, - max_epochs=1, - enable_model_summary=False, - ) - trainer.fit(model) @pytest.mark.parametrize("complete_epoch", [True, False]) @@ -534,24 +405,6 @@ def configure_optimizers(self): trainer.fit(model) -def test_invalid_opt_idx_in_scheduler(tmpdir): - """Test exception when incorrect opt_idx is set in lr_scheduler config.""" - - class InvalidOptimizerModel(BoringModel): - def configure_optimizers(self): - opt1 = optim.SGD(self.layer.parameters(), lr=0.1) - opt2 = optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = {"scheduler": optim.lr_scheduler.StepLR(opt2, step_size=1), "opt_idx": 0} - return [opt1, opt2], [lr_scheduler] - - model = InvalidOptimizerModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.raises( - MisconfigurationException, match="`opt_idx` .* does not match with the index of the respective optimizer" - ): - trainer.fit(model) - - def test_invalid_optimizer_dict_raises(tmpdir): """Test exception when lr_scheduler dict has no scheduler.""" @@ -661,7 +514,7 @@ def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir, save_on callbacks=[ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=save_on_train_epoch_end)], ) - class TestModel(BoringModel): + class Model(BoringModel): def training_step(self, batch, batch_idx): self.log("foo", batch_idx) return super().training_step(batch, batch_idx) @@ -686,7 +539,7 @@ def on_save_checkpoint(self, checkpoint): self.on_save_checkpoint_called = True - model = TestModel() + model = Model() model.training_epoch_end = None trainer.fit(model) assert model.on_save_checkpoint_called @@ -709,10 +562,10 @@ def load_state_dict(self, state_dict): ... 
     class CustomBoringModel(BoringModel):
-        def lr_scheduler_step(self, scheduler, optimizer_idx: int, metric):
+        def lr_scheduler_step(self, scheduler, metric):
             # step-level
             if isinstance(scheduler, torch.optim.lr_scheduler.StepLR):
-                super().lr_scheduler_step(scheduler, optimizer_idx, metric)
+                super().lr_scheduler_step(scheduler, metric)
             # epoch-level, custom scheduler
             elif isinstance(scheduler, CustomEpochScheduler):
                 scheduler.step(epoch=self.current_epoch)
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index edbd1998b3e16..43de40031e450 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -32,13 +32,14 @@
 from tests_pytorch.helpers.utils import getattr_recursive
 
 
-def test_error_on_more_than_1_optimizer(tmpdir):
+def test_error_with_multiple_optimizers(tmpdir):
     """Check that error is thrown when more than 1 optimizer is passed."""
 
     class CustomBoringModel(BoringModel):
         def __init__(self, lr):
             super().__init__()
             self.save_hyperparameters()
+            self.automatic_optimization = False
 
         def configure_optimizers(self):
             optimizer1 = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
@@ -47,7 +48,6 @@ def configure_optimizers(self):
 
     model = CustomBoringModel(lr=1e-2)
 
-    # logger file to get meta
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
     tuner = Tuner(trainer)
 
diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py
index 11c804be9190b..c81e30a4b7c64 100644
--- a/tests/tests_pytorch/utilities/migration/test_migration.py
+++ b/tests/tests_pytorch/utilities/migration/test_migration.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import MagicMock
+from unittest.mock import ANY, MagicMock
 
 import pytest
 import torch
@@ -192,3 +192,33 @@ def test_migrate_loop_structure_after_tbptt_removal():
             "epoch_loop.manual_loop.optim_step_progress": optim_progress_manual,
         }
     }
+
+
+def test_migrate_loop_structure_after_optimizer_loop_removal():
+    """Test the loop state migration after multiple optimizer support in automatic optimization was removed in
+    2.0.0."""
+    state_automatic = MagicMock()
+    optim_progress_automatic = {
+        "optimizer": MagicMock(),
+        "optimizer_position": 33,
+    }
+    old_checkpoint = {
+        "loops": {
+            "fit_loop": {
+                "epoch_loop.state_dict": {"any": "state"},
+                "epoch_loop.batch_loop.state_dict": MagicMock(),
+                "epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic,
+                "epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
+            }
+        }
+    }
+    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
+    assert updated_checkpoint["loops"] == {
+        "fit_loop": {
+            "epoch_loop.state_dict": ANY,
+            "epoch_loop.optimizer_loop.state_dict": state_automatic,
+            # optimizer_position gets dropped:
+            "epoch_loop.optimizer_loop.optim_progress": {"optimizer": ANY},
+        }
+    }
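
Reviewer note (not part of the patch itself): the pattern that the updated tests above now rely on, for example `test_multiple_optimizers_manual` and the GAN-style model in `test_manual_optimization.py`, is manual optimization. The sketch below is illustrative only; the class name and losses are made up, but the calls (`self.optimizers()`, `self.manual_backward()`, `optimizer.step()` / `zero_grad()`) are the manual-optimization API these tests exercise:

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class TwoOptimizerModel(LightningModule):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.layer_a = torch.nn.Linear(32, 2)
        self.layer_b = torch.nn.Linear(32, 2)
        # mandatory as soon as configure_optimizers returns more than one optimizer
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # no `optimizer_idx` argument: the step sees the batch exactly once
        opt_a, opt_b = self.optimizers()

        out_a = self.layer_a(batch)
        loss_a = F.mse_loss(out_a, torch.ones_like(out_a))
        opt_a.zero_grad()
        self.manual_backward(loss_a)
        opt_a.step()

        out_b = self.layer_b(batch)
        loss_b = F.mse_loss(out_b, torch.zeros_like(out_b))
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()

    def configure_optimizers(self):
        opt_a = torch.optim.SGD(self.layer_a.parameters(), lr=0.1)
        opt_b = torch.optim.SGD(self.layer_b.parameters(), lr=0.1)
        return opt_a, opt_b

Because the module owns the stepping, updating an optimizer only on some batches (what the removed frequency-based configuration used to express) becomes a plain `if batch_idx % n == 0:` around the corresponding `step()` call.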
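Similarly, the hook signatures asserted in `test_hooks.py` and `test_optimizers.py` above no longer carry an optimizer index. A minimal sketch of overriding them, assuming the argument lists shown in those assertions (the class name and method bodies are placeholders, not part of this patch):

from pytorch_lightning import LightningModule


class HookOverridesSketch(LightningModule):  # hypothetical name; bodies are placeholders
    def backward(self, loss):
        loss.backward()

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad()

    def on_before_optimizer_step(self, optimizer):
        pass

    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        self.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        optimizer.step(closure=optimizer_closure)

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step()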