
Make manual optimization mandatory for multiple optimizers #16539

Merged 52 commits on Feb 1, 2023

Diff shown below: changes from 20 commits

Commits (52)
f3da9c1  remove optimizer_idx from code base (awaelchli, Jan 28, 2023)
89ba1f5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 28, 2023)
1d30f6c  frequency (awaelchli, Jan 28, 2023)
e08f559  remove test (awaelchli, Jan 28, 2023)
bb5ce22  fixes (awaelchli, Jan 28, 2023)
454f273  more tests (awaelchli, Jan 28, 2023)
041bdb1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 28, 2023)
5f4f161  fix (awaelchli, Jan 29, 2023)
973c197  fixes (awaelchli, Jan 29, 2023)
c691330  wip (awaelchli, Jan 29, 2023)
b29bef7  test loops (awaelchli, Jan 29, 2023)
19e6463  fixes (awaelchli, Jan 29, 2023)
0b51930  fix none loss return (awaelchli, Jan 29, 2023)
2ab83c6  gan (awaelchli, Jan 29, 2023)
0b6b725  fixes (awaelchli, Jan 29, 2023)
23a3af2  convert test (awaelchli, Jan 29, 2023)
7daede6  convert tests (awaelchli, Jan 29, 2023)
8ad4895  delete code (awaelchli, Jan 29, 2023)
8f84246  fixes (awaelchli, Jan 29, 2023)
8ed6c4e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 29, 2023)
af7bc8a  undo rename, do it in follow up (awaelchli, Jan 29, 2023)
71b8fd1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 29, 2023)
a041772  update test (awaelchli, Jan 30, 2023)
a6b2470  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 30, 2023)
0dc1a0a  fix tests (awaelchli, Jan 30, 2023)
6759f02  Merge remote-tracking branch 'origin/removal/optimizer-loop' into rem… (awaelchli, Jan 30, 2023)
b57a7a9  rework test (awaelchli, Jan 30, 2023)
5537bbd  fix lightning optimizer (awaelchli, Jan 30, 2023)
2ac387a  fix (awaelchli, Jan 30, 2023)
4db7f7e  fix ipu (awaelchli, Jan 30, 2023)
0c99e83  fix deepspeed (awaelchli, Jan 30, 2023)
1cec14e  deepspeed (awaelchli, Jan 30, 2023)
20988a3  unused (awaelchli, Jan 30, 2023)
b64879a  fix (awaelchli, Jan 30, 2023)
f0048f1  fixes (awaelchli, Jan 30, 2023)
c298276  Merge branch 'master' into removal/optimizer-loop (awaelchli, Jan 30, 2023)
07765f7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 30, 2023)
42c2905  Merge branch 'master' into removal/optimizer-loop (awaelchli, Jan 30, 2023)
64a318f  address todos (awaelchli, Jan 30, 2023)
c0d5c00  migration wip (awaelchli, Jan 30, 2023)
f4d398e  Merge branch 'master' into removal/optimizer-loop (awaelchli, Jan 30, 2023)
2aa1fb3  test for migration (awaelchli, Jan 30, 2023)
76ea50e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 30, 2023)
2b29352  todo (awaelchli, Jan 30, 2023)
b65ea43  fix merge conflict (awaelchli, Jan 30, 2023)
f921a27  huge changelog (awaelchli, Jan 30, 2023)
b41e080  accidental change (awaelchli, Jan 30, 2023)
476fe6b  accidental bad merge (awaelchli, Jan 30, 2023)
541e1fa  Update src/pytorch_lightning/core/module.py (awaelchli, Jan 31, 2023)
271bf50  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2023)
35a1f7b  Merge branch 'master' into removal/optimizer-loop (awaelchli, Feb 1, 2023)
941d135  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 1, 2023)
src/pytorch_lightning/callbacks/callback.py (2 changes: 1 addition & 1 deletion)

@@ -241,7 +241,7 @@ def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
     def on_before_optimizer_step(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
         """Called before ``optimizer.step()``."""
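Downstream Callback implementations drop the optimizer index accordingly. A minimal sketch of the new hook signature (the callback name and logged metric are illustrative, not part of this PR):

import pytorch_lightning as pl
from torch.optim import Optimizer


class GradNormLogger(pl.Callback):
    """Hypothetical callback overriding the hook with its new signature."""

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
    ) -> None:
        # Inspect gradients right before the optimizer step; no ``opt_idx`` is passed anymore.
        grad_norms = [p.grad.norm() for p in pl_module.parameters() if p.grad is not None]
        if grad_norms:
            pl_module.log("grad_norm_total", sum(grad_norms))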
src/pytorch_lightning/callbacks/finetuning.py (19 changes: 6 additions & 13 deletions)

@@ -46,7 +46,7 @@ class BaseFinetuning(Callback):
         and should be used to freeze any modules parameters.
 
     ``finetune_function``: This method is called on every train epoch start and should be used to
-        ``unfreeze`` any parameters. Those parameters needs to be added in a new ``param_group``
+        ``unfreeze`` any parameters. Those parameters need to be added in a new ``param_group``
         within the optimizer.
 
     .. note:: Make sure to filter the parameters based on ``requires_grad``.

@@ -69,7 +69,7 @@ class BaseFinetuning(Callback):
         ...         # Here, we are freezing `feature_extractor`
         ...         self.freeze(pl_module.feature_extractor)
         ...
-        ...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
+        ...     def finetune_function(self, pl_module, current_epoch, optimizer):
         ...         # When `current_epoch` is 10, feature_extractor will start training.
         ...         if current_epoch == self._unfreeze_at_epoch:
         ...             self.unfreeze_and_add_param_group(

@@ -290,18 +290,13 @@ def _store(
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the epoch begins."""
-        # import is here to avoid circular imports
-        from pytorch_lightning.loops.utilities import _get_active_optimizers
-
-        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0):
+        for opt_idx, optimizer in enumerate(trainer.optimizers):
             num_param_groups = len(optimizer.param_groups)
-            self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
+            self.finetune_function(pl_module, trainer.current_epoch, optimizer)
             current_param_groups = optimizer.param_groups
             self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
 
-    def finetune_function(
-        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
-    ) -> None:
+    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:
         """Override to add your unfreeze logic."""
         raise NotImplementedError

@@ -389,9 +384,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
     def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
         self.freeze(pl_module.backbone)
 
-    def finetune_function(
-        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
-    ) -> None:
+    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:
         """Called when the epoch begins."""
         if epoch == self.unfreeze_backbone_at_epoch:
             current_lr = optimizer.param_groups[0]["lr"]
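For reference, a complete version of the docstring example written against the new three-argument ``finetune_function`` (the ``feature_extractor`` attribute and the unfreeze epoch are assumptions for illustration):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from torch.optim import Optimizer


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    def __init__(self, unfreeze_at_epoch: int = 10) -> None:
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
        # Freeze the feature extractor until the configured epoch.
        self.freeze(pl_module.feature_extractor)

    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None:
        # No ``opt_idx`` argument: the hook is now called once per optimizer.
        if epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor,
                optimizer=optimizer,
                train_bn=True,
            )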
src/pytorch_lightning/callbacks/stochastic_weight_avg.py (4 changes: 2 additions & 2 deletions)

@@ -212,8 +212,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                     "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
                 )
 
-            # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
-            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
+            # We assert that there is only one optimizer on fit start
+            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
            assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
 
            if trainer.lr_scheduler_configs:
src/pytorch_lightning/core/hooks.py (6 changes: 3 additions & 3 deletions)

@@ -73,6 +73,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
             batch_idx: the index of the batch
         """
 
+    # TODO: Should 'outputs' be renamed to 'output' (singular)?
     def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
         """Called in the training loop after the batch.
 
@@ -229,7 +230,7 @@ def on_after_backward(self) -> None:
         Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
         """
 
-    def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.
 
         If using gradient accumulation, the hook is called once the gradients have been accumulated.

@@ -243,11 +244,10 @@ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
 
         Args:
             optimizer: Current optimizer being used.
-            optimizer_idx: Index of the current optimizer being used.
 
         Example::
 
-            def on_before_optimizer_step(self, optimizer, optimizer_idx):
+            def on_before_optimizer_step(self, optimizer):
                 # example to inspect gradient information in tensorboard
                 if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                     for k, v in self.named_parameters():
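The broader change these hook signatures reflect is the PR title: with more than one optimizer, a LightningModule must now use manual optimization, so no optimizer index is threaded through the hooks. A rough sketch of that pattern (module names, losses, and shapes are placeholders, not code from this PR):

import torch
import pytorch_lightning as pl


class TwoOptimizerModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Manual optimization is required when configure_optimizers returns multiple optimizers.
        self.automatic_optimization = False
        self.generator = torch.nn.Linear(32, 32)
        self.discriminator = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()

        # Generator update (placeholder loss for illustration).
        g_loss = self.discriminator(self.generator(batch)).mean()
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # Discriminator update (placeholder loss for illustration).
        d_loss = -self.discriminator(batch).mean()
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return opt_g, opt_d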