
Unify checkpoint load paths [redo #9693] #10061

Merged: 45 commits, Oct 25, 2021
Changes from 35 commits

Commits (45)
6ecf45b  first commit wip (jjenniferdai, Sep 24, 2021)
5a0d60c  test_lambda_fix (jjenniferdai, Sep 24, 2021)
5b7df74  more test updates (jjenniferdai, Sep 24, 2021)
b5dee8e  updates (jjenniferdai, Sep 24, 2021)
2bb5bc5  resume_start doc update (jjenniferdai, Sep 24, 2021)
48d200b  Merge branch 'master' into unify-cp-load-paths (jjenniferdai, Sep 24, 2021)
ef2d2cd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 24, 2021)
205a380  mypy (jjenniferdai, Sep 24, 2021)
cd5b5c0  Merge branch 'master' into unify-cp-load-paths (jjenniferdai, Sep 27, 2021)
ea41b41  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 27, 2021)
7d89b88  add resume_end, depr trainer.resume_checkpoint_path (jjenniferdai, Sep 27, 2021)
df7d4a9  fit arg order (jjenniferdai, Sep 28, 2021)
1f350c6  Merge branch 'master' into unify-cp-load-paths (jjenniferdai, Sep 30, 2021)
4315cf5  bring back properties (jjenniferdai, Sep 30, 2021)
d5069ec  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 30, 2021)
9d2a568  Merge branch 'unify-cp-load-paths' of https://github.com/jjenniferdai… (jjenniferdai, Oct 21, 2021)
2b06bf4  first mergee Merge branch 'jjenniferdai-unify-cp-load-paths' (jjenniferdai, Oct 21, 2021)
02675f7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 21, 2021)
0ef1235  docs update (jjenniferdai, Oct 21, 2021)
0cc3b15  fit() docs update (jjenniferdai, Oct 22, 2021)
9134475  Update `` docs/source/advanced/advanced_gpu.rst (jjenniferdai, Oct 22, 2021)
9549ff1  Update `` docs/source/common/trainer.rst (jjenniferdai, Oct 22, 2021)
666411d  Update trainer init deprecation msg (jjenniferdai, Oct 22, 2021)
3535692  Update validate: _run(ckpt_path=) named arg (jjenniferdai, Oct 22, 2021)
da76397  update deprecation warning msg (jjenniferdai, Oct 22, 2021)
c0b0fca  Update warn msg pytorch_lightning/trainer/callback_hook.py (jjenniferdai, Oct 22, 2021)
dd3fa7f  all _run(ckpt_path=) named arg, doc update (jjenniferdai, Oct 22, 2021)
2f693ed  resume_end clear resume_from_checkpoint_fit_path as well (jjenniferdai, Oct 22, 2021)
eb6c438  update clear resume_from_checkpoint_fit_path for fit matching path only (jjenniferdai, Oct 22, 2021)
ec9b15d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 22, 2021)
da4f803  mypy (jjenniferdai, Oct 22, 2021)
273bb2a  Update PR# CHANGELOG.md (jjenniferdai, Oct 22, 2021)
e5da1fd  doc Update pytorch_lightning/trainer/callback_hook.py (jjenniferdai, Oct 22, 2021)
4a70622  info msg, added arg changelog, tests update (jjenniferdai, Oct 22, 2021)
b14042e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 22, 2021)
c261efa  Update CHANGELOG.md (jjenniferdai, Oct 22, 2021)
58393a2  shorter in statement (jjenniferdai, Oct 22, 2021)
89a9d3d  Update CHANGELOG.md extra line (jjenniferdai, Oct 22, 2021)
71546f5  Merge branch 'master' into master (jjenniferdai, Oct 22, 2021)
b8b6325  test updates (jjenniferdai, Oct 23, 2021)
918625c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 23, 2021)
14eff7f  add some asserts (jjenniferdai, Oct 23, 2021)
4f1f39e  Merge branch 'master' into jjenniferdai/master (tchaton, Oct 25, 2021)
ebc4b63  Merge branch 'master' into master (jjenniferdai, Oct 25, 2021)
e8932a1  Merge branch 'master' into master (jjenniferdai, Oct 25, 2021)
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -219,6 +219,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))

- Added `ckpt_path` argument for `trainer.fit()` to replace `Trainer(resume_from_checkpoint=)` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))


### Changed

@@ -403,6 +405,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))


- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
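For context on the two CHANGELOG entries above, a minimal before/after sketch of the migration this PR introduces (the model class and checkpoint path are placeholders, not part of the diff):

```python
from pytorch_lightning import Trainer

model = MyLightningModule()  # placeholder LightningModule

# Deprecated in v1.5, removed in v1.7:
trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
trainer.fit(model)

# Preferred:
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
```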
2 changes: 1 addition & 1 deletion docs/source/advanced/advanced_gpu.rst
@@ -622,7 +622,7 @@ After training using ZeRO Stage 3, you'll notice that your checkpoints are a dir

.. warning::

This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the `resume_from_checkpoint` Trainer argument. Ensure to keep the sharded checkpoint directory if this is required.
This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the ``trainer.fit(ckpt_path=)`` call. Ensure to keep the sharded checkpoint directory if this is required.

Custom DeepSpeed Config
"""""""""""""""""""""""
4 changes: 4 additions & 0 deletions docs/source/common/trainer.rst
@@ -1347,6 +1347,10 @@ By setting to False, you have to add your own distributed sampler:
resume_from_checkpoint
^^^^^^^^^^^^^^^^^^^^^^

.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead.


.. raw:: html

<video width="50%" max-width="400px" controls
4 changes: 2 additions & 2 deletions docs/source/common/weights_loading.rst
@@ -212,7 +212,7 @@ do the following:
.. code-block:: python

model = LitModel()
trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
trainer = Trainer()

# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
2 changes: 1 addition & 1 deletion docs/source/extensions/loops_advanced.rst
@@ -30,7 +30,7 @@ The two hooks :class:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and
def on_load_checkpoint(self, state_dict):
self.iteration = state_dict["iteration"]

When the Trainer is restarting from a checkpoint (e.g., through :code:`Trainer(resume_from_checkpoint=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
When the Trainer is restarting from a checkpoint (e.g., through :code:`trainer.fit(ckpt_path=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
Based around the value of this variable, the user can write the loop in such a way that it can restart from an arbitrary point given the state loaded from the checkpoint.
For example, the implementation of the :meth:`~pytorch_lightning.loops.base.Loop.reset` method could look like this given our previous example:

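The diff cuts off before the `reset` example the paragraph refers to. A hedged sketch of the pattern it describes, reusing the `iteration` field from the `on_load_checkpoint` snippet above (everything else is illustrative, not the documented implementation):

```python
def reset(self) -> None:
    # When restarting from a checkpoint, `self.iteration` was already set in
    # `on_load_checkpoint`, so only re-initialize progress for a fresh run.
    if not self.restarting:
        self.iteration = 0
```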
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
@@ -279,8 +279,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
difference = callback_states.keys() - current_callbacks_keys
if difference:
rank_zero_warn(
"Be aware that when using `resume_from_checkpoint`,"
" callbacks used to create the checkpoint need to be provided."
"Be aware that when using `ckpt_path`,"
" callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
f" Please add the following callbacks: {list(difference)}.",
UserWarning,
)
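To illustrate the reworded warning: the callbacks that wrote their state into the checkpoint must be passed to the new `Trainer` when resuming, otherwise the warning above lists them as missing. A hedged example (callback arguments, paths, and `model` are placeholders):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# The original run used these callbacks, so their states live in the checkpoint;
# provide the same callbacks at Trainer instantiation when resuming.
callbacks = [ModelCheckpoint(dirpath="ckpts"), EarlyStopping(monitor="val_loss")]

trainer = Trainer(callbacks=callbacks)
trainer.fit(model, ckpt_path="ckpts/last.ckpt")  # `model` is a placeholder LightningModule
```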
47 changes: 29 additions & 18 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -38,7 +38,14 @@
class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
self.resume_checkpoint_path = resume_from_checkpoint
self.resume_checkpoint_path: Optional[_PATH] = None
# TODO: remove resume_from_checkpoint_fit_path in v1.7
self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint
if resume_from_checkpoint is not None:
rank_zero_deprecation(
"Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
" will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead."
)
self._loaded_checkpoint: Dict[str, Any] = {}

@property
@@ -53,14 +60,14 @@ def hpc_resume_path(self) -> Optional[str]:
if os.path.exists(auto_save_checkpoint):
return auto_save_checkpoint

def resume_start(self) -> None:
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
2. from `checkpoint_path` file if provided
3. don't restore
"""
self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
return
@@ -83,8 +90,22 @@ def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any
def resume_end(self) -> None:
"""Signal the connector that all states have resumed and memory for the checkpoint object can be
released."""
if self.resume_checkpoint_path:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
if self.resume_checkpoint_path and self.trainer.state.fn is not None:
if self.trainer.state.fn == TrainerFn.FITTING:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
elif (
self.trainer.state.fn == TrainerFn.VALIDATING
or self.trainer.state.fn == TrainerFn.TESTING
or self.trainer.state.fn == TrainerFn.PREDICTING
):
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
# TODO: remove resume_from_checkpoint_fit_path in v1.7
if (
self.trainer.state.fn is not None
and self.trainer.state.fn == TrainerFn.FITTING
and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path
):
self.resume_from_checkpoint_fit_path = None
self.resume_checkpoint_path = None
self._loaded_checkpoint = {}

@@ -99,16 +120,15 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
state-restore, in this priority:

1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
2. from `checkpoint_path` file if provided
3. don't restore

All restored states are listed in return value description of `dump_checkpoint`.

Args:
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
"""
self.resume_checkpoint_path = checkpoint_path
self.resume_start()
self.resume_start(checkpoint_path)

# restore module states
self.restore_datamodule()
@@ -157,15 +177,6 @@ def restore_model(self) -> None:
if isinstance(module, Metric):
module.reset()

def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.

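A condensed restatement (not the actual implementation) of the source-selection order documented in `resume_start` above, assuming only the two inputs visible in the diff:

```python
from typing import Optional


def select_resume_path(hpc_resume_path: Optional[str], checkpoint_path: Optional[str]) -> Optional[str]:
    # 1. an HPC auto-save checkpoint always wins
    # 2. otherwise fall back to the path passed in (e.g. from `fit(ckpt_path=...)`)
    # 3. None means nothing is restored
    return hpc_resume_path or checkpoint_path
```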
88 changes: 41 additions & 47 deletions pytorch_lightning/trainer/trainer.py
@@ -361,6 +361,10 @@ def __init__(
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.

.. deprecated:: v1.5
``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.

strategy: Supports different training strategies with aliases
as well custom training type plugins.

@@ -615,6 +619,7 @@ def fit(
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
train_dataloader=None, # TODO: remove with 1.6
ckpt_path: Optional[str] = None,
) -> None:
r"""
Runs the full optimization routine.
Expand All @@ -628,6 +633,10 @@ def fit(

val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.

datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
if train_dataloader is not None:
Expand All @@ -636,14 +645,17 @@ def fit(
" Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
)
train_dataloaders = train_dataloader
self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)
self._call_and_handle_interrupt(
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)

def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
Trainer._log_api_event("fit")

@@ -666,7 +678,9 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self._run(model)
# TODO: ckpt_path only in v1.7
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.training = False
@@ -753,7 +767,7 @@ def _validate_impl(
)

# run validate
results = self._run(model)
results = self._run(model, ckpt_path=self.validated_ckpt_path)

assert self.state.stopped
self.validating = False
@@ -843,7 +857,7 @@ def _test_impl(
)

# run test
results = self._run(model)
results = self._run(model, ckpt_path=self.tested_ckpt_path)

assert self.state.stopped
self.testing = False
@@ -926,7 +940,7 @@ def _predict_impl(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

results = self._run(model)
results = self._run(model, ckpt_path=self.predicted_ckpt_path)

assert self.state.stopped
self.predicting = False
@@ -995,24 +1009,18 @@ def tune(

return result

def _restore_modules_and_callbacks(self) -> None:
if self.state.fn != TrainerFn.FITTING:
return

self.checkpoint_connector.restore_datamodule()
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self.checkpoint_connector.resume_start(checkpoint_path)
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _load_checkpoint_weights(self):
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn == TrainerFn.FITTING:
self.checkpoint_connector.restore_datamodule()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)
@@ -1029,9 +1037,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self._load_checkpoint_weights()

# ----------------------------
# SET UP TRAINING
# ----------------------------
@@ -1040,9 +1045,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment

# check if we should delay restoring checkpoint till later
if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
self.checkpoint_connector.resume_start()
self._restore_modules_and_callbacks()
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._restore_modules_and_callbacks(ckpt_path)

self._call_configure_sharded_model() # allow user to setup in model sharded environment
self.accelerator.setup(self)
@@ -1090,16 +1094,14 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
# plugin will setup fitting (e.g. ddp will launch child processes)
self._pre_dispatch()

if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()

self.checkpoint_connector.resume_start()
self._restore_modules_and_callbacks()
if self.accelerator.restore_checkpoint_after_pre_dispatch:
self._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()

self.checkpoint_connector.resume_end()

# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()

@@ -1199,9 +1201,6 @@ def _pre_training_routine(self):
# register signals
self.signal_connector.register_signal_handlers()

if self.state.fn != TrainerFn.TUNING:
self.checkpoint_connector.resume_end()

# --------------------------
# Pre-train
# --------------------------
@@ -1801,7 +1800,11 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:

@property
def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
return self.checkpoint_connector.resume_checkpoint_path
rank_zero_deprecation(
"`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
" Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead."
)
return self.checkpoint_connector.resume_from_checkpoint_fit_path

def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
self.checkpoint_connector.save_checkpoint(filepath, weights_only)
@@ -2026,15 +2029,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validated_ckpt_path
if self.state.fn == TrainerFn.TESTING:
return self.tested_ckpt_path
if self.state.fn == TrainerFn.PREDICTING:
return self.predicted_ckpt_path

"""
Logging properties
"""
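The user-facing effect of routing every entry point through `_run(ckpt_path=...)`, sketched with placeholder paths and a placeholder `model`: `fit` restores the full training state, while `validate`/`test`/`predict` only load model weights, matching the `resume_end` messages in the connector diff.

```python
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=3)

trainer.fit(model, ckpt_path="checkpoints/epoch=1.ckpt")       # restores model, optimizers, loops, callbacks
trainer.validate(model, ckpt_path="checkpoints/epoch=1.ckpt")  # loads model weights only
trainer.test(model, ckpt_path="checkpoints/epoch=1.ckpt")
trainer.predict(model, ckpt_path="checkpoints/epoch=1.ckpt")
```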
6 changes: 2 additions & 4 deletions tests/accelerators/test_cpu.py
@@ -80,10 +80,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

trainer = Trainer(
default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
)
trainer.fit(model)
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True)
trainer.fit(model, ckpt_path=checkpoint_path)
for func in (trainer.test, trainer.validate, trainer.predict):
accelerator.training_type_plugin.predispatched_called = False
func(model, ckpt_path=checkpoint_path)
4 changes: 2 additions & 2 deletions tests/accelerators/test_tpu.py
@@ -62,8 +62,8 @@ def test_resume_training_on_cpu(tmpdir):
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(resume_from_checkpoint=model_path, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model, ckpt_path=model_path)
assert trainer.state.finished, f"Training failed with {trainer.state}"

