progressive restoring of trainer state #7652

Merged · 36 commits from feature/resume-9 into master · Jun 17, 2021

Changes from 25 commits

Commits
a2560f1  deprecate (awaelchli, Jun 12, 2021)
88f2015  test (awaelchli, Jun 12, 2021)
b0c0b07  tests (awaelchli, Jun 12, 2021)
0f17119  ypf (awaelchli, Jun 12, 2021)
3aef4e4  all (awaelchli, Jun 12, 2021)
3cc54b8  clean up (awaelchli, Jun 12, 2021)
0fa9807  clean up (awaelchli, Jun 12, 2021)
f62cd51  test hook calls (awaelchli, Jun 12, 2021)
09dd67d  space (awaelchli, Jun 12, 2021)
ce2887e  Merge branch 'feature/resume-8' into feature/resume-9 (awaelchli, Jun 12, 2021)
1dea5be  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 12, 2021)
1a2c6e4  unused import (awaelchli, Jun 12, 2021)
4695bec  Merge branch 'feature/resume-8' into feature/resume-9 (awaelchli, Jun 12, 2021)
af25625  fix info message (awaelchli, Jun 12, 2021)
de6e6d9  move (awaelchli, Jun 12, 2021)
d89a98e  update changelog (awaelchli, Jun 12, 2021)
777a297  chlog (awaelchli, Jun 12, 2021)
ce98239  test moving resume end after pre_dispatch (awaelchli, Jun 14, 2021)
872157f  Merge branch 'master' into feature/resume-9 (awaelchli, Jun 14, 2021)
492856c  wip (awaelchli, Jun 14, 2021)
71c74fe  move (awaelchli, Jun 15, 2021)
0f88897  Revert "wip" (awaelchli, Jun 14, 2021)
1480442  move misplaced resume_end() (awaelchli, Jun 16, 2021)
c94fd78  Merge branch 'master' into feature/resume-9 (awaelchli, Jun 16, 2021)
a987f1d  add guard to restore_datamodule (awaelchli, Jun 16, 2021)
c8ef693  rm duplicate comment (awaelchli, Jun 16, 2021)
6abf23a  Merge branch 'master' into feature/resume-9 (awaelchli, Jun 16, 2021)
6a38c9b  add hook test (awaelchli, Jun 16, 2021)
d208b5c  comment (awaelchli, Jun 16, 2021)
25918e7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 16, 2021)
3d9539d  blank (awaelchli, Jun 16, 2021)
f018bf8  Merge remote-tracking branch 'origin/feature/resume' into feature/res… (awaelchli, Jun 16, 2021)
b45c335  merge tests (awaelchli, Jun 16, 2021)
0b4f7a7  Merge branch 'master' into feature/resume-9 (awaelchli, Jun 16, 2021)
7e38deb  clarify how many batches need to run (awaelchli, Jun 16, 2021)
02c3d54  Update tests/models/test_hooks.py (awaelchli, Jun 17, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run` ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))


- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


### Deprecated


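To make the changelog entry concrete, here is a minimal sketch of the hook ordering it describes; the checkpoint path is a placeholder and the print statements are illustrative only, not part of this PR:

```python
import torch
import pytorch_lightning as pl


class SketchModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def setup(self, stage=None):
        # called first; with this PR, the checkpoint weights are restored
        # right after this hook returns
        print("setup: checkpoint not yet restored")

    def configure_sharded_model(self):
        # called after the restore, so sharding wrappers now see the
        # already-loaded weights instead of freshly initialized ones
        print("configure_sharded_model: checkpoint already restored")


# resume_from_checkpoint was the resume API at the time of this PR;
# "last.ckpt" is a placeholder path
trainer = pl.Trainer(resume_from_checkpoint="last.ckpt")
```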
69 changes: 38 additions & 31 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
import os
import re
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Optional, Union

import torch

@@ -82,7 +82,8 @@ def resume_start(self) -> None:

def resume_end(self) -> None:
""" Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
if self.resume_checkpoint_path:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
self.resume_checkpoint_path = None
self._loaded_checkpoint = dict()

@@ -93,53 +94,63 @@ def resume_end(self) -> None:
# wait for all to catch up
self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")

def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
"""
Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file
through file-read and state-restore, in this priority:

1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
3. don't restore

All restored states are listed in the return value description of `dump_checkpoint`.

Args:
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
"""
self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
self.resume_checkpoint_path = checkpoint_path
self.resume_start()
model = self.trainer.lightning_module

self.restore_model_state(model, self._loaded_checkpoint)
# restore module states
self.restore_datamodule()
self.restore_model()

if self.trainer._device_type == DeviceType.GPU:
model.cuda(self.trainer.root_gpu)
# restore callback states
self.restore_callbacks()

# restore training state
if self._loaded_checkpoint:
self.restore_training_state(self._loaded_checkpoint)

self.restore_training_state()
self.resume_end()
return True

def restore_model_state(self, model: LightningModule, checkpoint) -> None:
def restore_datamodule(self) -> None:
""" Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. """
if not self._loaded_checkpoint:
return

datamodule = self.trainer.datamodule
if datamodule is not None:
datamodule.on_load_checkpoint(self._loaded_checkpoint)

def restore_model(self) -> None:
"""
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first to give
the LightningModule a chance to modify the contents, then finally the model gets updated with
the loaded weights.
"""
if not checkpoint:
if not self._loaded_checkpoint:
return

# restore datamodule states
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)
model = self.trainer.lightning_module

# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(checkpoint)
model.on_load_checkpoint(self._loaded_checkpoint)

# call hpc specific hook
if self.hpc_resume_path is not None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
""" Restore only the model weights. """
@@ -150,19 +161,16 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) ->
self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
def restore_training_state(self) -> None:
"""
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
optimizer states and learning rate scheduler states.
"""
if not checkpoint:
if not self._loaded_checkpoint:
return

# restore precision plugin (scaler etc.)
self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

self.restore_callbacks()

self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
# restore progress (loops etc.)
self.restore_progress()

@@ -232,10 +240,8 @@ def restore_optimizers(self) -> None:
return

# restore the optimizers
optimizer_states = self._loaded_checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)

self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
for optimizer in self.trainer.optimizers:
# move optimizer to GPU 1 weight at a time
# avoids OOM
if self.trainer.root_gpu is not None:
@@ -357,6 +363,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

self.trainer.precision_plugin.on_save_checkpoint(checkpoint)

# dump hyper-parameters
# dump hyper-parameters
if model.hparams:
if hasattr(model, '_hparams_name'):
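Since `restore_datamodule()` above hands the loaded checkpoint dict to the datamodule, user code can carry state across resumes. A sketch under the assumption that `LightningDataModule` also exposes the matching `on_save_checkpoint` hook; the `seen_samples` key is made up for illustration:

```python
import pytorch_lightning as pl


class StatefulDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.seen_samples = 0  # hypothetical state worth persisting

    def on_save_checkpoint(self, checkpoint):
        # write custom state into the checkpoint dict at save time
        checkpoint["seen_samples"] = self.seen_samples

    def on_load_checkpoint(self, checkpoint):
        # restore_datamodule() passes the loaded checkpoint dict here
        self.seen_samples = checkpoint.get("seen_samples", 0)
```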
17 changes: 14 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -492,6 +492,8 @@ def fit(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self.checkpoint_connector.resume_start()

self._run(model)

assert self.state.stopped
@@ -801,6 +803,13 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
self.accelerator.connect(model)
self.accelerator.setup_environment()
self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment

# restore modules after setup
self.checkpoint_connector.restore_datamodule()
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()

self._call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

@@ -842,6 +851,9 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()

# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()

# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()

@@ -904,6 +916,8 @@ def _pre_training_routine(self):
# register auto-resubmit when on SLURM
self.slurm_connector.register_slurm_signal_handlers()

self.checkpoint_connector.resume_end()

# --------------------------
# Pre-train
# --------------------------
@@ -917,9 +931,6 @@ def _pre_training_routine(self):
if self.is_global_zero and self.weights_summary is not None and not self.testing:
ref_model.summarize(mode=self.weights_summary)

# restore training and model before hpc is called
self.checkpoint_connector.restore()

# on pretrain routine end
self.on_pretrain_routine_end()
ref_model.on_pretrain_routine_end()
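Reading the trainer hunks together gives the full progressive sequence. A comment-only sketch of the approximate call order during a resumed fit (simplified; the private internals are paraphrased, not exact method names):

```python
# fit():
#     checkpoint_connector.resume_start()            # read checkpoint into memory
# _run():
#     call_setup_hook(model)                         # LightningModule.setup()
#     checkpoint_connector.restore_datamodule()      # datamodule state
#     checkpoint_connector.restore_model()           # weights + model hooks
#     checkpoint_connector.restore_callbacks()       # callback states
#     call_configure_sharded_model(model)            # sees restored weights
#     pre_dispatch()                                 # e.g. DDP launches processes
#     checkpoint_connector.restore_training_state()  # precision, loops, optimizers
#     dispatch()                                     # start training/evaluating/predicting
# _pre_training_routine():
#     checkpoint_connector.resume_end()              # log message, release memory
```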
2 changes: 1 addition & 1 deletion tests/callbacks/test_finetuning_callback.py
@@ -275,7 +275,7 @@ def configure_optimizers(self):
model = FreezeModel()
cb = OnEpochLayerFinetuning()
trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
with pytest.raises(IndexError, match="index 6 is out of range"):
with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"):
trainer.fit(model)


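The new expected message is the one `torch.optim.Optimizer.load_state_dict` itself raises when the saved and current parameter-group counts differ, which the plugin-based optimizer restore above now surfaces directly. A standalone sketch of that PyTorch behavior, independent of Lightning:

```python
import torch

net = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

state = opt.state_dict()
state["param_groups"] = state["param_groups"] * 2  # 2 saved groups vs 1 current

try:
    opt.load_state_dict(state)
except ValueError as err:
    print(err)  # "loaded state dict has a different number of parameter groups"
```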