Unify checkpoint load paths [redo Lightning-AI#9693] (Lightning-AI#10061)
jjenniferdai authored and ninginthecloud committed Oct 27, 2021
1 parent 458769c commit 2db5dd0
Showing 32 changed files with 209 additions and 180 deletions.
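In short, the commit moves checkpoint resumption from the ``Trainer`` constructor to the ``fit`` call. A minimal before/after sketch (``LitModel`` stands in for any user-defined ``LightningModule``; the checkpoint path is illustrative):

    from pytorch_lightning import Trainer

    model = LitModel()

    # before: deprecated in v1.5, to be removed in v1.7
    trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
    trainer.fit(model)

    # after: pass the checkpoint path to fit() instead
    trainer = Trainer()
    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")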
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -223,6 +223,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))


- Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))


### Changed

@@ -415,6 +417,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated access to the `AcceleratorConnector.configure_slurm_ddp` method and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
2 changes: 1 addition & 1 deletion docs/source/advanced/advanced_gpu.rst
@@ -622,7 +622,7 @@ After training using ZeRO Stage 3, you'll notice that your checkpoints are a dir
.. warning::

This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the `resume_from_checkpoint` Trainer argument. Ensure to keep the sharded checkpoint directory if this is required.
This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the ``trainer.fit(ckpt_path=)`` call. Ensure to keep the sharded checkpoint directory if this is required.
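A hedged sketch of the workflow this warning implies: keep the sharded ZeRO Stage 3 checkpoint directory and resume from it with ``trainer.fit(ckpt_path=)`` instead of from the consolidated single file (the strategy alias and paths below are illustrative assumptions):

    from pytorch_lightning import Trainer

    # the sharded checkpoint is a directory and still holds the optimizer/lr-scheduler states
    trainer = Trainer(strategy="deepspeed_stage_3", precision=16, gpus=4)
    trainer.fit(model, ckpt_path="checkpoints/epoch=4-step=999.ckpt")  # path to the sharded checkpoint directory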

Custom DeepSpeed Config
"""""""""""""""""""""""
4 changes: 4 additions & 0 deletions docs/source/common/trainer.rst
@@ -1349,6 +1349,10 @@ By setting to False, you have to add your own distributed sampler:
resume_from_checkpoint
^^^^^^^^^^^^^^^^^^^^^^

.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead.


.. raw:: html

<video width="50%" max-width="400px" controls
4 changes: 2 additions & 2 deletions docs/source/common/weights_loading.rst
@@ -212,7 +212,7 @@ do the following:
.. code-block:: python
model = LitModel()
trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
trainer = Trainer()
# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
2 changes: 1 addition & 1 deletion docs/source/extensions/loops_advanced.rst
@@ -30,7 +30,7 @@ The two hooks :class:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and
def on_load_checkpoint(self, state_dict):
self.iteration = state_dict["iteration"]
When the Trainer is restarting from a checkpoint (e.g., through :code:`Trainer(resume_from_checkpoint=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
When the Trainer is restarting from a checkpoint (e.g., through :code:`trainer.fit(ckpt_path=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
Based around the value of this variable, the user can write the loop in such a way that it can restart from an arbitrary point given the state loaded from the checkpoint.
For example, the implementation of the :meth:`~pytorch_lightning.loops.base.Loop.reset` method could look like this given our previous example:
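A hedged sketch of what such a ``reset`` could look like for the iteration-counter example above (the actual docs implementation is not shown in this hunk):

    def reset(self) -> None:
        # keep the iteration restored from the checkpoint when restarting;
        # otherwise start counting from zero
        if not self.restarting:
            self.iteration = 0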

4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
@@ -279,8 +279,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
difference = callback_states.keys() - current_callbacks_keys
if difference:
rank_zero_warn(
"Be aware that when using `resume_from_checkpoint`,"
" callbacks used to create the checkpoint need to be provided."
"Be aware that when using `ckpt_path`,"
" callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
f" Please add the following callbacks: {list(difference)}.",
UserWarning,
)
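For illustration, a hedged sketch of what this warning asks for: re-create the callbacks that produced the checkpoint when instantiating the resuming ``Trainer`` (the callbacks and path below are only examples):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    # the checkpoint was written by a run that used these callbacks,
    # so pass them again so their states can be restored
    trainer = Trainer(callbacks=[ModelCheckpoint(monitor="val_loss"), EarlyStopping(monitor="val_loss")])
    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")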
41 changes: 24 additions & 17 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -38,7 +38,14 @@
class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
self.resume_checkpoint_path = resume_from_checkpoint
self.resume_checkpoint_path: Optional[_PATH] = None
# TODO: remove resume_from_checkpoint_fit_path in v1.7
self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint
if resume_from_checkpoint is not None:
rank_zero_deprecation(
"Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
" will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead."
)
self._loaded_checkpoint: Dict[str, Any] = {}

@property
@@ -53,14 +60,14 @@ def hpc_resume_path(self) -> Optional[str]:
if os.path.exists(auto_save_checkpoint):
return auto_save_checkpoint

def resume_start(self) -> None:
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
2. from `checkpoint_path` file if provided
3. don't restore
"""
self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
return
@@ -83,8 +90,18 @@ def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any
def resume_end(self) -> None:
"""Signal the connector that all states have resumed and memory for the checkpoint object can be
released."""
assert self.trainer.state.fn is not None
if self.resume_checkpoint_path:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
if self.trainer.state.fn == TrainerFn.FITTING:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
# TODO: remove resume_from_checkpoint_fit_path in v1.7
if (
self.trainer.state.fn == TrainerFn.FITTING
and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path
):
self.resume_from_checkpoint_fit_path = None
self.resume_checkpoint_path = None
self._loaded_checkpoint = {}

@@ -99,16 +116,15 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
state-restore, in this priority:
1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
2. from `checkpoint_path` file if provided
3. don't restore
All restored states are listed in return value description of `dump_checkpoint`.
Args:
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
"""
self.resume_checkpoint_path = checkpoint_path
self.resume_start()
self.resume_start(checkpoint_path)

# restore module states
self.restore_datamodule()
@@ -157,15 +173,6 @@ def restore_model(self) -> None:
if isinstance(module, Metric):
module.reset()

def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.
88 changes: 41 additions & 47 deletions pytorch_lightning/trainer/trainer.py
@@ -362,6 +362,10 @@ def __init__(
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
.. deprecated:: v1.5
``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.
strategy: Supports different training strategies with aliases
as well custom training type plugins.
@@ -617,6 +621,7 @@ def fit(
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
train_dataloader=None, # TODO: remove with 1.6
ckpt_path: Optional[str] = None,
) -> None:
r"""
Runs the full optimization routine.
@@ -630,6 +635,10 @@ def fit(
val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
if train_dataloader is not None:
@@ -638,14 +647,17 @@ def fit(
" Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
)
train_dataloaders = train_dataloader
self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)
self._call_and_handle_interrupt(
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)

def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
Trainer._log_api_event("fit")

@@ -668,7 +680,9 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self._run(model)
# TODO: ckpt_path only in v1.7
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.training = False
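For context, a hedged sketch of the backward-compatible behavior implied by the ``ckpt_path or self.resume_from_checkpoint`` fallback above (paths are illustrative):

    # the deprecated constructor argument still acts as the default for fit()
    trainer = Trainer(resume_from_checkpoint="old/way.ckpt")  # emits a deprecation warning
    trainer.fit(model)  # resumes from "old/way.ckpt"

    # an explicit ckpt_path passed to fit() takes precedence
    trainer.fit(model, ckpt_path="new/way.ckpt")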
@@ -755,7 +769,7 @@ def _validate_impl(
)

# run validate
results = self._run(model)
results = self._run(model, ckpt_path=self.validated_ckpt_path)

assert self.state.stopped
self.validating = False
@@ -845,7 +859,7 @@ def _test_impl(
)

# run test
results = self._run(model)
results = self._run(model, ckpt_path=self.tested_ckpt_path)

assert self.state.stopped
self.testing = False
@@ -928,7 +942,7 @@ def _predict_impl(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

results = self._run(model)
results = self._run(model, ckpt_path=self.predicted_ckpt_path)

assert self.state.stopped
self.predicting = False
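The evaluation entry points follow the same pattern; a hedged usage sketch (checkpoint path is illustrative):

    # validate / test / predict restore only the model weights from the given checkpoint
    trainer.validate(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
    trainer.test(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
    trainer.predict(model, ckpt_path="some/path/to/my_checkpoint.ckpt")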
@@ -997,24 +1011,18 @@ def tune(

return result

def _restore_modules_and_callbacks(self) -> None:
if self.state.fn != TrainerFn.FITTING:
return

self.checkpoint_connector.restore_datamodule()
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self.checkpoint_connector.resume_start(checkpoint_path)
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _load_checkpoint_weights(self):
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn == TrainerFn.FITTING:
self.checkpoint_connector.restore_datamodule()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)
@@ -1031,9 +1039,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self._data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self._load_checkpoint_weights()

# ----------------------------
# SET UP TRAINING
# ----------------------------
@@ -1042,9 +1047,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment

# check if we should delay restoring checkpoint till later
if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self.checkpoint_connector.resume_start()
self._restore_modules_and_callbacks()
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._restore_modules_and_callbacks(ckpt_path)

self._call_configure_sharded_model() # allow user to setup in model sharded environment
self.accelerator.setup(self)
@@ -1092,16 +1096,14 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()

if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()

self.checkpoint_connector.resume_start()
self._restore_modules_and_callbacks()
if self.accelerator.restore_checkpoint_after_pre_dispatch:
self._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()

self.checkpoint_connector.resume_end()

# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()

@@ -1201,9 +1203,6 @@ def _pre_training_routine(self):
# register signals
self.signal_connector.register_signal_handlers()

if self.state.fn != TrainerFn.TUNING:
self.checkpoint_connector.resume_end()

# --------------------------
# Pre-train
# --------------------------
@@ -1804,7 +1803,11 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:

@property
def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
return self.checkpoint_connector.resume_checkpoint_path
rank_zero_deprecation(
"`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
" Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead."
)
return self.checkpoint_connector.resume_from_checkpoint_fit_path

def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
self.checkpoint_connector.save_checkpoint(filepath, weights_only)
@@ -2029,15 +2032,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validated_ckpt_path
if self.state.fn == TrainerFn.TESTING:
return self.tested_ckpt_path
if self.state.fn == TrainerFn.PREDICTING:
return self.predicted_ckpt_path

"""
Logging properties
"""
6 changes: 2 additions & 4 deletions tests/accelerators/test_cpu.py
@@ -80,10 +80,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

trainer = Trainer(
default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
)
trainer.fit(model)
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True)
trainer.fit(model, ckpt_path=checkpoint_path)
for func in (trainer.test, trainer.validate, trainer.predict):
accelerator.training_type_plugin.predispatched_called = False
func(model, ckpt_path=checkpoint_path)
4 changes: 2 additions & 2 deletions tests/accelerators/test_tpu.py
@@ -62,8 +62,8 @@ def test_resume_training_on_cpu(tmpdir):
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(resume_from_checkpoint=model_path, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model, ckpt_path=model_path)
assert trainer.state.finished, f"Training failed with {trainer.state}"

