diff --git a/CHANGELOG.md b/CHANGELOG.md
index 75da2d41716fa..3ff0677690570 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -187,6 +187,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
 
+- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))
+
+
 ### Deprecated
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 8035f0c532764..b599caf91e20d 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,7 +15,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Optional, Union
 
 import torch
 
@@ -82,7 +82,8 @@ def resume_start(self) -> None:
 
     def resume_end(self) -> None:
         """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. """
-        rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
+        if self.resume_checkpoint_path:
+            rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
         self.resume_checkpoint_path = None
         self._loaded_checkpoint = dict()
 
@@ -93,9 +94,9 @@ def resume_end(self) -> None:
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
 
-    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
+    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
         """
-        Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
+        Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file
         through file-read and state-restore, in this priority:
 
         1. from HPC weights if found
@@ -103,43 +104,53 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
         3. don't restore
 
         All restored states are listed in return value description of `dump_checkpoint`.
+
+        Args:
+            checkpoint_path: Path to a PyTorch Lightning checkpoint file.
         """
-        self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
+        self.resume_checkpoint_path = checkpoint_path
        self.resume_start()
-        model = self.trainer.lightning_module
 
-        self.restore_model_state(model, self._loaded_checkpoint)
+        # restore module states
+        self.restore_datamodule()
+        self.restore_model()
 
-        if self.trainer._device_type == DeviceType.GPU:
-            model.cuda(self.trainer.root_gpu)
+        # restore callback states
+        self.restore_callbacks()
 
         # restore training state
-        if self._loaded_checkpoint:
-            self.restore_training_state(self._loaded_checkpoint)
-
+        self.restore_training_state()
         self.resume_end()
-        return True
 
-    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+    def restore_datamodule(self) -> None:
+        """ Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. """
+        if not self._loaded_checkpoint:
+            return
+
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule.on_load_checkpoint(self._loaded_checkpoint)
+
+    def restore_model(self) -> None:
         """
-        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first to give
+        the LightningModule a chance to modify the contents, then finally the model gets updated with
+        the loaded weights.
         """
-        if not checkpoint:
+        if not self._loaded_checkpoint:
             return
 
-        # restore datamodule states
-        if self.trainer.datamodule is not None:
-            self.trainer.datamodule.on_load_checkpoint(checkpoint)
+        model = self.trainer.lightning_module
 
         # hook: give user access to checkpoint if needed.
-        model.on_load_checkpoint(checkpoint)
+        model.on_load_checkpoint(self._loaded_checkpoint)
 
         # call hpc specific hook
         if self.hpc_resume_path is not None:
             model.on_hpc_load(self._loaded_checkpoint)
 
         # restore model state_dict
-        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
+        self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
 
     def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
         """ Restore only the model weights. """
@@ -150,19 +161,16 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) ->
         self.trainer.lightning_module.on_load_checkpoint(checkpoint)
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
 
-    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
+    def restore_training_state(self) -> None:
         """
         Restore the trainer state from the pre-loaded checkpoint.
         This includes the precision settings, loop progress, optimizer states and learning rate scheduler states.
         """
-        if not checkpoint:
+        if not self._loaded_checkpoint:
             return
 
         # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
-
-        self.restore_callbacks()
-
+        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
         # restore progress (loops etc.)
         self.restore_progress()
 
@@ -232,10 +240,8 @@ def restore_optimizers(self) -> None:
             return
 
         # restore the optimizers
-        optimizer_states = self._loaded_checkpoint['optimizer_states']
-        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
-            optimizer.load_state_dict(opt_state)
-
+        self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
+        for optimizer in self.trainer.optimizers:
             # move optimizer to GPU 1 weight at a time
             # avoids OOM
             if self.trainer.root_gpu is not None:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 158fe20beee77..6979a859b0e9a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -492,6 +492,8 @@ def fit(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )
 
+        self.checkpoint_connector.resume_start()
+
         self._run(model)
 
         assert self.state.stopped
@@ -801,6 +803,13 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
         self.accelerator.connect(model)
         self.accelerator.setup_environment()
         self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
+
+        # restore modules after setup
+        self.checkpoint_connector.restore_datamodule()
+        self.checkpoint_connector.restore_model()
+        # restore callback states
+        self.checkpoint_connector.restore_callbacks()
+
         self._call_configure_sharded_model(model)  # allow user to setup in model sharded environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
@@ -842,6 +851,9 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
         # plugin will setup fitting (e.g. ddp will launch child processes)
         self._pre_dispatch()
 
+        # restore optimizers, etc.
+        self.checkpoint_connector.restore_training_state()
+
         # dispatch `start_training` or `start_evaluating` or `start_predicting`
         self._dispatch()
 
@@ -904,6 +916,8 @@ def _pre_training_routine(self):
         # register auto-resubmit when on SLURM
         self.slurm_connector.register_slurm_signal_handlers()
 
+        self.checkpoint_connector.resume_end()
+
         # --------------------------
         # Pre-train
         # --------------------------
@@ -917,9 +931,6 @@ def _pre_training_routine(self):
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
             ref_model.summarize(mode=self.weights_summary)
 
-        # restore training and model before hpc is called
-        self.checkpoint_connector.restore()
-
         # on pretrain routine end
         self.on_pretrain_routine_end()
         ref_model.on_pretrain_routine_end()
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 0d6a0e3f0a3d1..ef421e9219725 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -275,7 +275,7 @@ def configure_optimizers(self):
     model = FreezeModel()
     cb = OnEpochLayerFinetuning()
     trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
-    with pytest.raises(IndexError, match="index 6 is out of range"):
+    with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"):
         trainer.fit(model)
 
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 7ab93e9ad2621..9938f756bf3fa 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -418,17 +418,31 @@ def test_trainer_model_hook_system_fit(tmpdir):
     assert called == expected
 
 
-def test_trainer_model_hook_system_fit_no_val(tmpdir):
+def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
+    # initial training to get a checkpoint
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        limit_val_batches=0,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    best_model_path = trainer.checkpoint_callback.best_model_path
+
+    # resume from checkpoint with HookedModel
     called = []
     model = HookedModel(called)
     train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=1,
+        # already performed 1 step, now resuming to do an additional 2
+        max_steps=(1 + train_batches),
         limit_val_batches=0,
-        limit_train_batches=train_batches,
         progress_bar_refresh_rate=0,
         weights_summary=None,
+        resume_from_checkpoint=best_model_path,
     )
     assert called == []
     trainer.fit(model)
@@ -436,6 +450,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir):
         'prepare_data',
         'configure_callbacks',
         'setup',
+        'on_load_checkpoint',
         'configure_sharded_model',
         'configure_optimizers',
         'on_fit_start',