split restore_training_state into logical parts [2 / 2] (#7900)
awaelchli authored Jun 10, 2021
1 parent d209b68 commit c1eac48
Showing 3 changed files with 14 additions and 68 deletions.
78 changes: 12 additions & 66 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
import os
import re
from pathlib import Path
from typing import Optional, Union
from typing import Any, Dict, Optional, Union

import torch

@@ -115,7 +115,7 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:

# restore training state
if self._loaded_checkpoint:
self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)
self.restore_training_state(self._loaded_checkpoint)

self.resume_end()
return True
@@ -135,77 +135,23 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
# restore model state_dict
model.load_state_dict(checkpoint['state_dict'])

def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
"""
Restore trainer state.
Model will get its change to update
:param checkpoint:
:return:
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
optimizer states and learning rate scheduler states.
"""
if not checkpoint:
return

# validation
if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
raise KeyError(
'Trying to restore training state but checkpoint contains only the model.'
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
)

if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
raise ValueError(
"The checkpoint you're attempting to load follows an"
" outdated schema. You can upgrade to the current schema by running"
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
" where `model.ckpt` is your checkpoint file."
)

# restore precision plugin (scaler etc.)
self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

# restore callback states
self.trainer.on_load_checkpoint(checkpoint)

self.trainer.train_loop.global_step = checkpoint['global_step']
self.trainer.train_loop.current_epoch = checkpoint['epoch']

# crash if max_epochs is lower then the current epoch from the checkpoint
if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
m = f"""
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
but the Trainer(max_epochs={self.trainer.max_epochs})
"""
raise MisconfigurationException(m)

# Division deals with global step stepping once per accumulated batch
# Inequality deals with different global step for odd vs even num_training_batches
n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
expected_steps = self.trainer.num_training_batches / n_accum
if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
rank_zero_warn(
"You're resuming from a checkpoint that ended mid-epoch."
" Training will start from the beginning of the next epoch."
" This can cause unreliable results if further training is done,"
" consider using an end of epoch checkpoint."
)

if not load_optimizer_states:
return

# restore the optimizers
optimizer_states = checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
self.restore_callbacks()

# move optimizer to GPU 1 weight at a time
# avoids OOM
if self.trainer.root_gpu is not None:
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda(self.trainer.root_gpu)
# restore progress (loops etc.)
self.restore_progress()

# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)
self.restore_optimizers_and_schedulers()

def restore_callbacks(self) -> None:
""" Restores all callbacks from the pre-loaded checkpoint. """
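Pieced together from the added lines in the hunk above, the refactored method becomes a thin dispatcher. A sketch of how the post-commit `restore_training_state` reads (the bodies of the three `restore_*` helpers are defined elsewhere in `CheckpointConnector`, presumably added in part [1 / 2] of this series, and are not part of this hunk):

```python
def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
    """
    Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings,
    loop progress, optimizer states and learning rate scheduler states.
    """
    if not checkpoint:
        return

    # restore precision plugin (scaler etc.)
    self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

    # callback states, previously restored inline via self.trainer.on_load_checkpoint(checkpoint)
    self.restore_callbacks()

    # restore progress (loops etc.); the removed global_step / current_epoch bookkeeping
    # and the max_epochs / mid-epoch sanity checks presumably move into this helper
    self.restore_progress()

    # optimizer and lr scheduler states, replacing the two removed restore loops
    self.restore_optimizers_and_schedulers()
```

Note that the `load_optimizer_states` flag disappears from both the signature and the call site in `restore`, so any decision about skipping optimizer state presumably now lives inside `restore_optimizers_and_schedulers`.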
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
@@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
callbacks=[early_stop_callback],
)

with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
with pytest.raises(MisconfigurationException, match=r'You restored a checkpoint with current_epoch'):
new_trainer.fit(model)


2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -391,7 +391,7 @@ def test_model_checkpoint_only_weights(tmpdir):

# assert restoring train state fails
with pytest.raises(KeyError, match="checkpoint contains only the model"):
trainer.checkpoint_connector.restore_training_state(checkpoint)
trainer.checkpoint_connector.restore(new_weights_path)


def test_model_freeze_unfreeze():
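The updated assertion above now goes through the public `restore(...)` path instead of calling `restore_training_state` directly. As a rough, hypothetical reproduction of the scenario the test exercises (the `BoringModel` import path and the file name are assumptions, not taken from this diff): a weights-only checkpoint carries no `optimizer_states` or `lr_schedulers` entries, so restoring the full training state from it is expected to raise the `KeyError` matched above.

```python
import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # assumed import path; any minimal LightningModule works


def test_restore_from_weights_only_checkpoint(tmpdir):
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    # write a checkpoint that contains only the model weights,
    # i.e. no 'optimizer_states' or 'lr_schedulers' keys
    weights_path = str(tmpdir / "weights_only.ckpt")
    trainer.save_checkpoint(weights_path, weights_only=True)

    # restoring the full training state from such a checkpoint is expected to fail
    with pytest.raises(KeyError, match="checkpoint contains only the model"):
        trainer.checkpoint_connector.restore(weights_path)
```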
