progressive restoring of trainer state (#7652)
awaelchli authored Jun 17, 2021
1 parent 3fece17 commit eebdc91
Showing 5 changed files with 73 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -194,6 +194,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))


- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


### Deprecated


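To make the new ordering concrete: a module resuming via `Trainer(resume_from_checkpoint=...)` can now rely on its weights being restored between `setup()` and `configure_sharded_model()` (before this PR, the restore happened later, in the pre-train routine). A minimal sketch assuming the Lightning 1.3-era API; the checkpoint path and model below are placeholders, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ResumableModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def setup(self, stage=None):
        # runs first; with this change, checkpoint weights are restored right after this hook
        print("setup done, weights about to be restored")

    def configure_sharded_model(self):
        # the restored weights are already present by the time this hook runs
        print("weight sum after restore:", self.layer.weight.sum().item())

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)
# "last.ckpt" stands in for a checkpoint written by an earlier run
trainer = pl.Trainer(max_steps=2, resume_from_checkpoint="last.ckpt")
trainer.fit(ResumableModel(), train_loader)
```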
68 changes: 37 additions & 31 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
import os
import re
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Optional, Union

import torch

@@ -82,7 +82,8 @@ def resume_start(self) -> None:

def resume_end(self) -> None:
""" Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
if self.resume_checkpoint_path:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
self.resume_checkpoint_path = None
self._loaded_checkpoint = dict()

@@ -93,53 +94,63 @@ def resume_end(self) -> None:
# wait for all to catch up
self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")

def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
"""
Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file
through file-read and state-restore, in this priority:
1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
3. don't restore
All restored states are listed in the return value description of `dump_checkpoint`.
Args:
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
"""
self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
self.resume_checkpoint_path = checkpoint_path
self.resume_start()
model = self.trainer.lightning_module

self.restore_model_state(model, self._loaded_checkpoint)
# restore module states
self.restore_datamodule()
self.restore_model()

if self.trainer._device_type == DeviceType.GPU:
model.cuda(self.trainer.root_gpu)
# restore callback states
self.restore_callbacks()

# restore training state
if self._loaded_checkpoint:
self.restore_training_state(self._loaded_checkpoint)

self.restore_training_state()
self.resume_end()
return True

def restore_model_state(self, model: LightningModule, checkpoint) -> None:
def restore_datamodule(self) -> None:
""" Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. """
if not self._loaded_checkpoint:
return

datamodule = self.trainer.datamodule
if datamodule is not None:
datamodule.on_load_checkpoint(self._loaded_checkpoint)

def restore_model(self) -> None:
"""
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first to give
the LightningModule a chance to modify the contents, then finally the model gets updated with
the loaded weights.
"""
if not checkpoint:
if not self._loaded_checkpoint:
return

# restore datamodule states
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)
model = self.trainer.lightning_module

# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(checkpoint)
model.on_load_checkpoint(self._loaded_checkpoint)

# call hpc specific hook
if self.hpc_resume_path is not None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
""" Restore only the model weights. """
@@ -150,19 +161,16 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) ->
self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
def restore_training_state(self) -> None:
"""
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
optimizer states and learning rate scheduler states.
"""
if not checkpoint:
if not self._loaded_checkpoint:
return

# restore precision plugin (scaler etc.)
self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

self.restore_callbacks()

self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
# restore progress (loops etc.)
self.restore_progress()

@@ -232,10 +240,8 @@ def restore_optimizers(self) -> None:
return

# restore the optimizers
optimizer_states = self._loaded_checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)

self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
for optimizer in self.trainer.optimizers:
# move optimizer to GPU 1 weight at a time
# avoids OOM
if self.trainer.root_gpu is not None:
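The reworked `restore_model()` above calls `LightningModule.on_load_checkpoint` before `load_model_state_dict`, so the module still gets a chance to edit the checkpoint contents before the weights are applied. A short sketch of that hook; the key renaming is purely hypothetical:

```python
import torch
import pytorch_lightning as pl


class MigratingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Linear(32, 2)

    def on_load_checkpoint(self, checkpoint):
        # invoked by restore_model() before the state dict is applied,
        # so stale keys can still be rewritten in place
        state_dict = checkpoint["state_dict"]
        if "classifier.weight" in state_dict:  # hypothetical old layer name
            state_dict["head.weight"] = state_dict.pop("classifier.weight")
            state_dict["head.bias"] = state_dict.pop("classifier.bias")
```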
17 changes: 14 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -492,6 +492,8 @@ def fit(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self.checkpoint_connector.resume_start()

self._run(model)

assert self.state.stopped
@@ -801,6 +803,13 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
self.accelerator.connect(model)
self.accelerator.setup_environment()
self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment

# restore modules after setup
self.checkpoint_connector.restore_datamodule()
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()

self._call_configure_sharded_model(model) # allow user to setup in model sharded environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

@@ -842,6 +851,9 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()

# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()

# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()

@@ -904,6 +916,8 @@ def _pre_training_routine(self):
# register auto-resubmit when on SLURM
self.slurm_connector.register_slurm_signal_handlers()

self.checkpoint_connector.resume_end()

# --------------------------
# Pre-train
# --------------------------
@@ -917,9 +931,6 @@ def _pre_training_routine(self):
if self.is_global_zero and self.weights_summary is not None and not self.testing:
ref_model.summarize(mode=self.weights_summary)

# restore training and model before hpc is called
self.checkpoint_connector.restore()

# on pretrain routine end
self.on_pretrain_routine_end()
ref_model.on_pretrain_routine_end()
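End to end, the relocated restore calls mean a second `fit()` picks up where the checkpoint left off: modules and callbacks are restored after `setup()`, optimizer and scheduler state just before dispatch. A minimal sketch mirroring the new resume test in tests/models/test_hooks.py below; paths and step counts are illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)

# first run: a single optimizer step, producing a checkpoint
trainer = pl.Trainer(max_steps=1, default_root_dir="/tmp/resume_demo")
trainer.fit(TinyModel(), data)
ckpt_path = trainer.checkpoint_callback.best_model_path

# second run: resume from that checkpoint and continue training
trainer = pl.Trainer(max_steps=3, resume_from_checkpoint=ckpt_path, default_root_dir="/tmp/resume_demo")
trainer.fit(TinyModel(), data)
print(trainer.global_step)  # expected: 3 (1 step from the first run + 2 more after resuming)
```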
2 changes: 1 addition & 1 deletion tests/callbacks/test_finetuning_callback.py
@@ -275,7 +275,7 @@ def configure_optimizers(self):
model = FreezeModel()
cb = OnEpochLayerFinetuning()
trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
with pytest.raises(IndexError, match="index 6 is out of range"):
with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"):
trainer.fit(model)


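The updated test above now matches the error that plain PyTorch raises when an optimizer state dict with a different number of parameter groups is loaded. A standalone illustration of that PyTorch behaviour, independent of Lightning:

```python
import torch

# torch.optim.Optimizer.load_state_dict rejects a state dict whose number of
# parameter groups differs from the optimizer's current groups, before any
# state is copied.
w1 = torch.nn.Parameter(torch.randn(2))
w2 = torch.nn.Parameter(torch.randn(2))

opt_one_group = torch.optim.SGD([w1], lr=0.1)
opt_two_groups = torch.optim.SGD([{"params": [w1]}, {"params": [w2]}], lr=0.1)

try:
    opt_two_groups.load_state_dict(opt_one_group.state_dict())
except ValueError as err:
    print(err)  # "loaded state dict has a different number of parameter groups"
```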
21 changes: 18 additions & 3 deletions tests/models/test_hooks.py
@@ -418,24 +418,39 @@ def test_trainer_model_hook_system_fit(tmpdir):
assert called == expected


def test_trainer_model_hook_system_fit_no_val(tmpdir):
def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
# initial training to get a checkpoint
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
limit_val_batches=0,
progress_bar_refresh_rate=0,
weights_summary=None,
)
trainer.fit(model)
best_model_path = trainer.checkpoint_callback.best_model_path

# resume from checkpoint with HookedModel
called = []
model = HookedModel(called)
train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
# already performed 1 step, now resuming to do an additional 2
max_steps=(1 + train_batches),
limit_val_batches=0,
limit_train_batches=train_batches,
progress_bar_refresh_rate=0,
weights_summary=None,
resume_from_checkpoint=best_model_path,
)
assert called == []
trainer.fit(model)
expected = [
'prepare_data',
'configure_callbacks',
'setup',
'on_load_checkpoint',
'configure_sharded_model',
'configure_optimizers',
'on_fit_start',
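`HookedModel` is defined earlier in tests/models/test_hooks.py and is not part of this diff; it appends each hook's name to the shared `called` list, which is what the `expected` sequence above asserts against. A condensed, simplified stand-in:

```python
import torch
import pytorch_lightning as pl


class RecordingModel(pl.LightningModule):
    """Condensed stand-in for HookedModel: each overridden hook logs its name."""

    def __init__(self, called):
        super().__init__()
        self.called = called
        self.layer = torch.nn.Linear(32, 2)

    def setup(self, stage=None):
        self.called.append("setup")

    def on_load_checkpoint(self, checkpoint):
        self.called.append("on_load_checkpoint")

    def configure_sharded_model(self):
        self.called.append("configure_sharded_model")

    def training_step(self, batch, batch_idx):
        self.called.append("training_step")
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        self.called.append("configure_optimizers")
        return torch.optim.SGD(self.parameters(), lr=0.1)
```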
