Make manual optimization mandatory for multiple optimizers (#16539)
Co-authored-by: Carlos Mocholí <[email protected]>
awaelchli and carmocca authored Feb 1, 2023
1 parent df09370 commit 6a56586
Showing 56 changed files with 344 additions and 1,512 deletions.
@@ -75,7 +75,7 @@ def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
     def freeze_before_training(self, pl_module: LightningModule):
         self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)
 
-    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
+    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer):
         if epoch == self.milestones[0]:
             # unfreeze 5 last layers
             self.unfreeze_and_add_param_group(
20 changes: 20 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `LightningModule.{un}toggle_optimizer` methods no longer accept a `optimizer_idx` argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly ([#16560](https://github.com/Lightning-AI/lightning/pull/16560))
 
 
+- Manual optimization is now required for working with multiple optimizers ([#16539](https://github.com/Lightning-AI/lightning/pull/16539))
+
+
 ### Deprecated
 
 -
@@ -173,6 +176,23 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the `using_lbfgs` argument from `LightningModule.optimizer_step` hook ([#16538](https://github.com/Lightning-AI/lightning/pull/16538))
 
 
+- Removed support for multiple optimizers in automatic optimization mode ([#16539](https://github.com/Lightning-AI/lightning/pull/16539))
+    * Removed `opt_idx` argument from `BaseFinetuning.finetune_function` callback method
+    * Removed `opt_idx` argument from `Callback.on_before_optimizer_step` callback method
+    * Removed `optimizer_idx` as an optional argument in `LightningModule.training_step`
+    * Removed `optimizer_idx` argument from `LightningModule.on_before_optimizer_step`
+    * Removed `optimizer_idx` argument from `LightningModule.configure_gradient_clipping`
+    * Removed `optimizer_idx` argument from `LightningModule.optimizer_step`
+    * Removed `optimizer_idx` argument from `LightningModule.optimizer_zero_grad`
+    * Removed `optimizer_idx` argument from `LightningModule.lr_scheduler_step`
+    * Removed support for declaring optimizer frequencies in the dictionary returned from `LightningModule.configure_optimizers`
+    * Removed arguments `optimizer` and `optimizer_idx` from `LightningModule.backward`
+    * Removed `optimizer_idx` argument from `PrecisionPlugin.optimizer_step` and all of its overrides in subclasses
+    * Removed `optimizer_idx` argument from `PrecisionPlugin.{optimizer_step,backward}` and all of its overrides in subclasses
+    * Removed `optimizer_idx` argument from `Strategy.{optimizer_step,backward}` and all of its overrides in subclasses
+    * Removed `Trainer.optimizer_frequencies` attribute
+
+
 ### Fixed
 
 - Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369))
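The entries above remove multi-optimizer support from automatic optimization. For context, here is a minimal sketch of the manual-optimization pattern that replaces it, assuming a hypothetical GAN-style module with a generator and a discriminator (the losses, latent size, and batch layout are placeholders, not part of this commit):

```python
import torch
import pytorch_lightning as pl


class HypotheticalGAN(pl.LightningModule):
    def __init__(self, generator: torch.nn.Module, discriminator: torch.nn.Module):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        # Multiple optimizers now require manual optimization.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # `training_step` no longer takes an `optimizer_idx`; fetch both optimizers instead.
        opt_g, opt_d = self.optimizers()
        real, _ = batch  # assumes an (images, labels) batch

        # Generator update with an illustrative placeholder loss.
        fake = self.generator(torch.randn(real.size(0), 100, device=self.device))
        g_loss = -self.discriminator(fake).mean()
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # Discriminator update with an illustrative placeholder loss.
        d_loss = self.discriminator(fake.detach()).mean() - self.discriminator(real).mean()
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

    def configure_optimizers(self):
        # Return a plain list; per-optimizer "frequency" dictionaries were removed.
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [opt_g, opt_d]
```

Per the #16560 entry above, `self.toggle_optimizer(opt_g)` and `self.untoggle_optimizer(opt_g)` now also take the optimizer object directly rather than an index.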
2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/callback.py
@@ -241,7 +241,7 @@ def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
     def on_before_optimizer_step(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
         """Called before ``optimizer.step()``."""
 
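For reference, a minimal sketch of a user callback written against the new hook signature; the callback name and the gradient-norm logging are illustrative assumptions, not part of this commit:

```python
import torch
import pytorch_lightning as pl
from torch.optim import Optimizer


class GradNormLogger(pl.Callback):
    """Hypothetical callback: logs the total gradient norm right before each optimizer step."""

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
    ) -> None:
        # The `opt_idx` parameter is gone; the optimizer object itself identifies the update.
        grads = [p.grad.detach().flatten() for p in pl_module.parameters() if p.grad is not None]
        if grads:
            pl_module.log("grad_norm/total", torch.cat(grads).norm())
```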
19 changes: 6 additions & 13 deletions src/pytorch_lightning/callbacks/finetuning.py
@@ -46,7 +46,7 @@ class BaseFinetuning(Callback):
         and should be used to freeze any modules parameters.
 
     ``finetune_function``: This method is called on every train epoch start and should be used to
-        ``unfreeze`` any parameters. Those parameters needs to be added in a new ``param_group``
+        ``unfreeze`` any parameters. Those parameters need to be added in a new ``param_group``
         within the optimizer.
 
     .. note:: Make sure to filter the parameters based on ``requires_grad``.
@@ -69,7 +69,7 @@ class BaseFinetuning(Callback):
         ...         # Here, we are freezing `feature_extractor`
         ...         self.freeze(pl_module.feature_extractor)
         ...
-        ...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
+        ...     def finetune_function(self, pl_module, current_epoch, optimizer):
         ...         # When `current_epoch` is 10, feature_extractor will start training.
         ...         if current_epoch == self._unfreeze_at_epoch:
         ...             self.unfreeze_and_add_param_group(
@@ -290,18 +290,13 @@ def _store(
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the epoch begins."""
-        # import is here to avoid circular imports
-        from pytorch_lightning.loops.utilities import _get_active_optimizers
-
-        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0):
+        for opt_idx, optimizer in enumerate(trainer.optimizers):
             num_param_groups = len(optimizer.param_groups)
-            self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
+            self.finetune_function(pl_module, trainer.current_epoch, optimizer)
             current_param_groups = optimizer.param_groups
             self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
 
-    def finetune_function(
-        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
-    ) -> None:
+    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:
         """Override to add your unfreeze logic."""
         raise NotImplementedError
 
@@ -389,9 +384,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
     def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
         self.freeze(pl_module.backbone)
 
-    def finetune_function(
-        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
-    ) -> None:
+    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:
         """Called when the epoch begins."""
         if epoch == self.unfreeze_backbone_at_epoch:
             current_lr = optimizer.param_groups[0]["lr"]
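For reference, a minimal sketch of a custom finetuning callback written against the updated `finetune_function` signature; the `backbone` attribute and the unfreeze epoch are illustrative assumptions:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from torch.optim import Optimizer


class UnfreezeBackbone(BaseFinetuning):
    """Hypothetical callback: freezes `pl_module.backbone` and unfreezes it at a given epoch."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
        # Assumes the LightningModule exposes a `backbone` submodule.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:
        # No `opt_idx` parameter anymore; the optimizer object is passed in directly.
        if epoch == self.unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                train_bn=True,
            )
```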
4 changes: 2 additions & 2 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -212,8 +212,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
                     "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
                 )
 
-            # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
-            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
+            # We assert that there is only one optimizer on fit start
+            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
             assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
 
             if trainer.lr_scheduler_configs:
5 changes: 2 additions & 3 deletions src/pytorch_lightning/core/hooks.py
@@ -229,7 +229,7 @@ def on_after_backward(self) -> None:
         Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
         """
 
-    def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.
 
         If using gradient accumulation, the hook is called once the gradients have been accumulated.
@@ -243,11 +243,10 @@ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) ->
         Args:
             optimizer: Current optimizer being used.
-            optimizer_idx: Index of the current optimizer being used.
 
         Example::
 
-            def on_before_optimizer_step(self, optimizer, optimizer_idx):
+            def on_before_optimizer_step(self, optimizer):
                 # example to inspect gradient information in tensorboard
                 if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                     for k, v in self.named_parameters():
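The same change applies to the `LightningModule`-side hooks. Below is a short sketch of the updated signatures, assuming the built-in `clip_gradients` helper and an illustrative logging interval; the model class and logged names are hypothetical:

```python
import pytorch_lightning as pl
from torch.optim import Optimizer


class MyModel(pl.LightningModule):
    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        # `optimizer_idx` is gone from the hook signature.
        if self.trainer.global_step % 25 == 0:  # keep logging volume small, as in the docstring example
            for name, param in self.named_parameters():
                if param.grad is not None:
                    self.log(f"grad_norm/{name}", param.grad.norm())

    def configure_gradient_clipping(self, optimizer: Optimizer, gradient_clip_val=None, gradient_clip_algorithm=None) -> None:
        # `optimizer_idx` was removed here as well; delegate to the standard helper.
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )
```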