From a2560f166c3c89ddeaf702302dadcc010c24c7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:32:13 +0200 Subject: [PATCH 01/29] deprecate --- CHANGELOG.md | 3 ++ .../connectors/checkpoint_connector.py | 40 +++++++------------ tests/deprecated_api/test_remove_1-4.py | 13 ++++++ tests/helpers/pipelines.py | 2 +- .../data/horovod/train_default_model.py | 2 +- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9bb747b70542..c075a28c8803e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -190,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) +- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6711ef3cb748e..dca163be14e51 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,7 +21,8 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn, \ + rank_zero_deprecation from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -248,6 +249,19 @@ def restore_lr_schedulers(self) -> None: # ---------------------------------- # PRIVATE OPS # ---------------------------------- + def hpc_load(self, checkpoint_path: str): + """ + Attempts to restore the full training and model state from a HPC checkpoint file. + .. deprecated:: + `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. + Use `CheckpointConnector.restore` instead. + """ + rank_zero_deprecation( + "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6." + " Use `CheckpointConnector.restore()` instead." + ) + self.restore(checkpoint_path) + def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists folderpath = str(folderpath) # because the tests pass a path object @@ -365,30 +379,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def hpc_load(self, checkpoint_path: str, on_gpu: bool): - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. - All restored states are listed in return value description of `dump_checkpoint`. 
- """ - - # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - - # acquire the model - model = self.trainer.lightning_module - - # restore model and datamodule state - self.restore_model_state(model, checkpoint) - - if self.trainer.root_gpu is not None: - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - # call hpc specific hook - model.on_hpc_load(checkpoint) - def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. Args: diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 37d8abfdf905d..23df12586d328 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx): with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"): trainer.fit(TestModel()) + + +def test_v1_4_0_deprecated_hpc_load(tmpdir): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger) + checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir)) + with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"): + trainer.checkpoint_connector.hpc_load(checkpoint_path) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index f7a6484f6b27e..02a9e2dd0cfb2 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -91,7 +91,7 @@ def run_model_test( trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.restore(checkpoint_path) @torch.no_grad() diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index ed0d33f5e8c82..c4cbaeb1363c9 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None: trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.restore(checkpoint_path) if on_gpu: trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1) From 88f201593c02a93cb0170327da5acbc96564534a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:33:40 +0200 Subject: [PATCH 02/29] test --- .../connectors/test_callback_connector.py | 13 +++ .../connectors/test_checkpoint_connector.py | 107 ++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 tests/trainer/connectors/test_checkpoint_connector.py diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..501482d77a240 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ 
b/tests/trainer/connectors/test_callback_connector.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import logging
 from unittest.mock import Mock
 
diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py
new file mode 100644
index 0000000000000..03982a474d02f
--- /dev/null
+++ b/tests/trainer/connectors/test_checkpoint_connector.py
@@ -0,0 +1,107 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+
+import torch
+
+from pytorch_lightning import Trainer
+from tests.helpers import BoringModel
+
+
+def test_preloaded_checkpoint_lifecycle(tmpdir):
+    """ Tests that the preloaded checkpoint contents get cleared from memory when they are no longer required. """
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+    )
+    trainer.fit(model)
+
+    connector = trainer.checkpoint_connector
+
+    assert not trainer.resume_from_checkpoint
+    assert not connector.resume_checkpoint_path
+    assert not connector._loaded_checkpoint
+
+    connector.resume_start()
+    assert not connector.resume_checkpoint_path
+    assert not connector._loaded_checkpoint
+    connector.resume_end()
+    assert not connector.resume_checkpoint_path
+    assert not connector._loaded_checkpoint
+
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
+    connector = trainer.checkpoint_connector
+    connector.resume_start()
+    assert connector.resume_checkpoint_path == ckpt_path
+    assert connector._loaded_checkpoint
+    assert isinstance(connector._loaded_checkpoint, dict)
+    connector.resume_end()
+    assert not connector.resume_checkpoint_path
+    assert not connector._loaded_checkpoint
+
+
+def test_hpc_restore_attempt(tmpdir):
+    """ Test that restore() attempts to restore the hpc_ckpt with highest priority. 
""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + + hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt" + trainer.save_checkpoint(hpc_ckpt_path) + assert Path(hpc_ckpt_path).exists() + + # set weights to zero + for param in model.parameters(): + torch.nn.init.constant_(param, 0) + + # case 1: restore hpc first, no explicit resume path provided + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) + trainer.fit(model) + + for param in model.parameters(): + assert param.abs().sum() > 0 + torch.nn.init.constant_(param, 0) + + # case 2: explicit resume path provided, restore hpc anyway + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint="not existing") + trainer.fit(model) + + for param in model.parameters(): + assert param.abs().sum() > 0 + + +def test_hpc_max_ckpt_version(tmpdir): + """ Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt") + + assert trainer.checkpoint_connector.hpc_resume_path == tmpdir / "hpc_ckpt_33.ckpt" + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None \ No newline at end of file From b0c0b0709beb916f26b947900ac3d91466ea54d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:44:50 +0200 Subject: [PATCH 03/29] tests --- .../connectors/test_checkpoint_connector.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 03982a474d02f..fc21e132aeb31 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from pathlib import Path import torch @@ -59,12 +60,14 @@ def test_hpc_restore_attempt(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_steps=1, + checkpoint_callback=False, + logger=False, ) trainer.fit(model) hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt" trainer.save_checkpoint(hpc_ckpt_path) - assert Path(hpc_ckpt_path).exists() + assert os.listdir(tmpdir) == ["hpc_ckpt_3.ckpt"] # set weights to zero for param in model.parameters(): @@ -74,6 +77,8 @@ def test_hpc_restore_attempt(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_steps=2, + checkpoint_callback=False, + logger=False, ) trainer.fit(model) @@ -82,7 +87,11 @@ def test_hpc_restore_attempt(tmpdir): torch.nn.init.constant_(param, 0) # case 2: explicit resume path provided, restore hpc anyway - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint="not existing") + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=3, + resume_from_checkpoint="not existing" + ) trainer.fit(model) for param in model.parameters(): @@ -104,4 +113,4 @@ def test_hpc_max_ckpt_version(tmpdir): assert trainer.checkpoint_connector.hpc_resume_path == tmpdir / "hpc_ckpt_33.ckpt" assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 - assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None \ No newline at end of file + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None From 0f17119af300a49faeb59bf5a13f1da04e8c67e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:46:21 +0200 Subject: [PATCH 04/29] ypf --- .../trainer/connectors/checkpoint_connector.py | 9 +++++++-- tests/trainer/connectors/test_checkpoint_connector.py | 7 +------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index dca163be14e51..f203b28c09048 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,8 +21,13 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn, \ - rank_zero_deprecation +from pytorch_lightning.utilities import ( + _OMEGACONF_AVAILABLE, + DeviceType, + rank_zero_deprecation, + rank_zero_info, + rank_zero_warn, +) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index fc21e132aeb31..a8c18e126733e 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from pathlib import Path import torch @@ -87,11 +86,7 @@ def test_hpc_restore_attempt(tmpdir): torch.nn.init.constant_(param, 0) # case 2: explicit resume path provided, restore hpc anyway - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=3, - resume_from_checkpoint="not existing" - ) + trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing") trainer.fit(model) for param in model.parameters(): From 3aef4e456332e7f4f0322b569f012805fb0d05ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:58:59 +0200 Subject: [PATCH 05/29] all --- pl_examples/bug_report_model.py | 125 ++++++++++++------ pl_examples/model_resume.py | 33 +++++ .../plugins/training_type/deepspeed.py | 2 +- .../connectors/checkpoint_connector.py | 84 +++++++----- pytorch_lightning/trainer/trainer.py | 18 ++- tests/callbacks/test_finetuning_callback.py | 2 +- tests/trainer/test_trainer.py | 5 +- 7 files changed, 190 insertions(+), 79 deletions(-) create mode 100644 pl_examples/model_resume.py diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index abb65ba86fd93..8e2d4009c6cea 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,67 +1,106 @@ +import logging import os +from typing import Any, Dict import torch -from torch.utils.data import DataLoader, Dataset +import torch.nn as nn +import torch.optim as optim +from torch.optim import AdamW +from torch.utils.data import DataLoader -from pytorch_lightning import LightningModule, Trainer +import pytorch_lightning as pl +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -class RandomDataset(Dataset): +class ToyModel(nn.Module): - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return self.data[index] + def __init__(self): + super().__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) - def __len__(self): - return self.len + def forward(self, x): + return self.net2(self.relu(self.net1(x))) -class BoringModel(LightningModule): +class ToyTask(pl.LightningModule): def __init__(self): super().__init__() - self.layer = torch.nn.Linear(32, 2) + self.loss_fn = nn.MSELoss() + + def setup(self, stage: str): + if stage == "test": + return + self.setup_model_and_optimizer() + print("setup called") + + def setup_model_and_optimizer(self): + self.model = ToyModel() + self.optimizer = AdamW( + self.model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1.0e-08, weight_decay=0, amsgrad=False + ) def forward(self, x): - return self.layer(x) + return self.model(x) def training_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("train_loss", loss) - return {"loss": loss} + targets = self.forward(batch["model_input"]) + loss = self.loss_fn(targets, batch["label"]) - def validation_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("valid_loss", loss) + # Log loss results per train step and per epoch + self.log("loss", loss) - def test_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("test_loss", loss) + # Tell Lightning to minimize loss + return loss def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) - - -def run(): - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - test_data = DataLoader(RandomDataset(32, 64), batch_size=2) - 
- model = BoringModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, + return self.optimizer + + # def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # self.setup_model_and_optimizer() + + +if __name__ == "__main__": + task = ToyTask() + + dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)] + + train_dataloader = DataLoader(dataset, batch_size=None) + val_dataloader = DataLoader(dataset, batch_size=None) + + model_checkpoint = ModelCheckpoint( + save_last=True, + every_n_val_epochs=1, ) - trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data) - trainer.test(model, test_dataloaders=test_data) + trainer = pl.Trainer( + gpus=2, + precision=16, + max_epochs=3, + progress_bar_refresh_rate=100, + log_gpu_memory=None, + reload_dataloaders_every_epoch=True, + limit_train_batches=10, + limit_val_batches=10, + limit_test_batches=10, + callbacks=[model_checkpoint], + ) + + results = trainer.fit(task, train_dataloader) -if __name__ == '__main__': - run() + print(model_checkpoint.last_model_path) + + trainer = pl.Trainer( + gpus=2, + precision=16, + max_epochs=4, + reload_dataloaders_every_epoch=True, + limit_train_batches=10, + limit_val_batches=10, + limit_test_batches=10, + callbacks=[model_checkpoint], + resume_from_checkpoint=model_checkpoint.last_model_path, + ) + trainer.fit(task, train_dataloader) diff --git a/pl_examples/model_resume.py b/pl_examples/model_resume.py new file mode 100644 index 0000000000000..f56e9750105d8 --- /dev/null +++ b/pl_examples/model_resume.py @@ -0,0 +1,33 @@ +import torch +from torch.utils.data import DataLoader + +import pytorch_lightning as pl +from pl_examples.bug_report_model import ToyTask +from pytorch_lightning.callbacks import ModelCheckpoint + +if __name__ == "__main__": + task = ToyTask() + + dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)] + + train_dataloader = DataLoader(dataset, batch_size=None) + val_dataloader = DataLoader(dataset, batch_size=None) + + model_checkpoint = ModelCheckpoint( + save_last=True, + every_n_val_epochs=1, + ) + + trainer = pl.Trainer( + gpus=2, + precision=16, + max_epochs=4, + reload_dataloaders_every_epoch=True, + limit_train_batches=10, + limit_val_batches=10, + limit_test_batches=10, + callbacks=[model_checkpoint], + resume_from_checkpoint= + "/home/adrian/repositories/pytorch-lightning/lightning_logs/version_82/checkpoints/last.ckpt", + ) + trainer.fit(task, train_dataloader) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index dc688de65cd34..5ccbad575ae15 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -247,7 +247,7 @@ def _load_config(self, config): config = json.load(f) return config - def pre_dispatch(self): + def pre_dispatch(self) -> None: self.init_deepspeed() self.barrier() diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index f203b28c09048..128c5501b79da 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -14,6 +14,7 @@ import os import re +from contextlib import contextmanager from pathlib import Path from typing import Any, Dict, 
Optional, Union
@@ -29,7 +30,6 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
-from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
@@ -53,6 +53,13 @@ def hpc_resume_path(self) -> Optional[str]:
         if max_version is not None:
             return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
 
+    def resume_from_checkpoint(self, path: Union[str, Path], **kwargs) -> None:
+        """
+        Signals the Trainer to resume from the given path the next time Trainer.fit/validate/test/predict is called.
+        """
+        self.resume_checkpoint_path = path
+        # TODO: decide what to resume
+
     def resume_start(self) -> None:
         """
         Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@@ -94,9 +101,18 @@ def resume_end(self) -> None:
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
 
-    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
+    # TODO: decide if we should use it or not (e.g., in Trainer.fit over self._run())
+    @contextmanager
+    def restore_ctx(self):
+        try:
+            self.resume_start()
+            yield
+        finally:
+            self.resume_end()
+
+    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
         """
-        Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
+        Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file
         through file-read and state-restore, in this priority:
 
         1. from HPC weights if found
@@ -104,39 +120,50 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
         2. from `resume_from_checkpoint` file if provided
         3. don't restore
 
         All restored states are listed in return value description of `dump_checkpoint`.
+
+        Args:
+            checkpoint_path: Path to a PyTorch Lightning checkpoint file.
         """
-        self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
+        self.resume_checkpoint_path = checkpoint_path
         self.resume_start()
-        model = self.trainer.lightning_module
 
-        self.restore_model_state(model, self._loaded_checkpoint)
+        # restore module states
+        self.restore_datamodule()
+        self.restore_model()
 
-        if self.trainer._device_type == DeviceType.GPU:
-            model.cuda(self.trainer.root_gpu)
+        # restore callback states
+        self.restore_callbacks()
 
         # restore training state
-        if self._loaded_checkpoint:
-            self.restore_training_state(self._loaded_checkpoint)
-
+        self.restore_training_state()
         self.resume_end()
-        return True
 
-    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+    def restore_datamodule(self) -> None:
+        """ Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. """
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule.on_load_checkpoint(self._loaded_checkpoint)
+
+    def restore_model(self) -> None:
         """
-        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first to give
+        the LightningModule a chance to modify the contents, then finally the model gets updated with
+        the loaded weights. 
""" - if not checkpoint: + if not self._loaded_checkpoint: return - # restore datamodule states - if self.trainer.datamodule is not None: - self.trainer.datamodule.on_load_checkpoint(checkpoint) + model = self.trainer.lightning_module # hook: give user access to checkpoint if needed. - model.on_load_checkpoint(checkpoint) + model.on_load_checkpoint(self._loaded_checkpoint) + + # call hpc specific hook + if self.hpc_resume_path is not None: + self.trainer.lightning_module.on_hpc_load(self._loaded_checkpoint) # restore model state_dict - self.trainer.training_type_plugin.load_model_state_dict(checkpoint) + self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None: """ Restore only the model weights. """ @@ -147,19 +174,16 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> self.trainer.lightning_module.on_load_checkpoint(checkpoint) self.trainer.training_type_plugin.load_model_state_dict(checkpoint) - def restore_training_state(self, checkpoint: Dict[str, Any]) -> None: + def restore_training_state(self) -> None: """ Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress, optimizer states and learning rate scheduler states. """ - if not checkpoint: + if not self._loaded_checkpoint: return # restore precision plugin (scaler etc.) - self.trainer.precision_plugin.on_load_checkpoint(checkpoint) - - self.restore_callbacks() - + self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) # restore progress (loops etc.) self.restore_progress() @@ -229,10 +253,8 @@ def restore_optimizers(self) -> None: return # restore the optimizers - optimizer_states = self._loaded_checkpoint['optimizer_states'] - for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): - optimizer.load_state_dict(opt_state) - + self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint) + for optimizer in self.trainer.optimizers: # move optimizer to GPU 1 weight at a time # avoids OOM if self.trainer.root_gpu is not None: @@ -257,6 +279,7 @@ def restore_lr_schedulers(self) -> None: def hpc_load(self, checkpoint_path: str): """ Attempts to restore the full training and model state from a HPC checkpoint file. + .. deprecated:: `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. Use `CheckpointConnector.restore` instead. 
@@ -364,6 +387,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: lr_schedulers.append(scheduler['scheduler'].state_dict()) checkpoint['lr_schedulers'] = lr_schedulers + # dump amp scaling self.trainer.precision_plugin.on_save_checkpoint(checkpoint) # dump hyper-parameters diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a6f93d9b4263d..85889e2ddc9ca 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -456,6 +456,9 @@ def fit( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) + self.checkpoint_connector.resume_start() + + # with self.checkpoint_connector.restore_ctx(): self._run(model) assert self.state.stopped @@ -732,7 +735,14 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED self.call_hook("on_before_accelerator_backend_setup", model) self.accelerator.connect(model) self.accelerator.setup_environment() - self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + self._call_setup_hook(model) # allow user to setup lightning_module in accelerator + + # restore modules after setup + self.checkpoint_connector.restore_datamodule() + self.checkpoint_connector.restore_model() + # restore callback states + self.checkpoint_connector.restore_callbacks() + self._call_configure_sharded_model(model) # allow user to setup in model sharded environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module @@ -806,6 +816,9 @@ def _pre_dispatch(self): self.logger.log_graph(self.lightning_module) self.logger.save() + # restore optimizers, etc. + self.checkpoint_connector.restore_training_state() + def _post_dispatch(self): self.accelerator.post_dispatch(self) self.accelerator.teardown() @@ -849,8 +862,7 @@ def _pre_training_routine(self): if self.is_global_zero and self.weights_summary is not None and not self.testing: ref_model.summarize(mode=self.weights_summary) - # restore training and model before hpc is called - self.checkpoint_connector.restore() + self.checkpoint_connector.resume_end() # on pretrain routine end self.on_pretrain_routine_end() diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 53d34c4645bef..fe8915e6e8443 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -275,7 +275,7 @@ def configure_optimizers(self): model = FreezeModel() cb = OnEpochLayerFinetuning() trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb]) - with pytest.raises(IndexError, match="index 6 is out of range"): + with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"): trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d353c0941d3f6..c04191a57bfa8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -428,7 +428,10 @@ def test_model_checkpoint_only_weights(tmpdir): # assert restoring train state fails with pytest.raises(KeyError, match="checkpoint contains only the model"): - trainer.checkpoint_connector.restore(new_weights_path) + trainer.checkpoint_connector.resume_from_checkpoint(new_weights_path) + trainer.checkpoint_connector.resume_start() + trainer.checkpoint_connector.restore_training_state() + trainer.checkpoint_connector.resume_end() def test_model_freeze_unfreeze(): From 
3cc54b8c99a993397eeb15684f80386ec4cfdfcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 15:09:34 +0200 Subject: [PATCH 06/29] clean up --- pl_examples/bug_report_model.py | 125 ++++++------------ pl_examples/model_resume.py | 33 ----- .../plugins/training_type/deepspeed.py | 2 +- tests/trainer/test_trainer.py | 5 +- 4 files changed, 45 insertions(+), 120 deletions(-) delete mode 100644 pl_examples/model_resume.py diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 8e2d4009c6cea..abb65ba86fd93 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,106 +1,67 @@ -import logging import os -from typing import Any, Dict import torch -import torch.nn as nn -import torch.optim as optim -from torch.optim import AdamW -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset -import pytorch_lightning as pl -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning import LightningModule, Trainer -class ToyModel(nn.Module): +class RandomDataset(Dataset): - def __init__(self): - super().__init__() - self.net1 = nn.Linear(10, 10) - self.relu = nn.ReLU() - self.net2 = nn.Linear(10, 5) + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) - def forward(self, x): - return self.net2(self.relu(self.net1(x))) + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len -class ToyTask(pl.LightningModule): +class BoringModel(LightningModule): def __init__(self): super().__init__() - self.loss_fn = nn.MSELoss() - - def setup(self, stage: str): - if stage == "test": - return - self.setup_model_and_optimizer() - print("setup called") - - def setup_model_and_optimizer(self): - self.model = ToyModel() - self.optimizer = AdamW( - self.model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1.0e-08, weight_decay=0, amsgrad=False - ) + self.layer = torch.nn.Linear(32, 2) def forward(self, x): - return self.model(x) + return self.layer(x) def training_step(self, batch, batch_idx): - targets = self.forward(batch["model_input"]) - loss = self.loss_fn(targets, batch["label"]) + loss = self(batch).sum() + self.log("train_loss", loss) + return {"loss": loss} - # Log loss results per train step and per epoch - self.log("loss", loss) + def validation_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("valid_loss", loss) - # Tell Lightning to minimize loss - return loss + def test_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("test_loss", loss) def configure_optimizers(self): - return self.optimizer - - # def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # self.setup_model_and_optimizer() - - -if __name__ == "__main__": - task = ToyTask() - - dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)] - - train_dataloader = DataLoader(dataset, batch_size=None) - val_dataloader = DataLoader(dataset, batch_size=None) - - model_checkpoint = ModelCheckpoint( - save_last=True, - every_n_val_epochs=1, + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def run(): + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + val_data = DataLoader(RandomDataset(32, 64), batch_size=2) + test_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + 
num_sanity_val_steps=0, + max_epochs=1, + weights_summary=None, ) + trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data) + trainer.test(model, test_dataloaders=test_data) - trainer = pl.Trainer( - gpus=2, - precision=16, - max_epochs=3, - progress_bar_refresh_rate=100, - log_gpu_memory=None, - reload_dataloaders_every_epoch=True, - limit_train_batches=10, - limit_val_batches=10, - limit_test_batches=10, - callbacks=[model_checkpoint], - ) - - results = trainer.fit(task, train_dataloader) - print(model_checkpoint.last_model_path) - - trainer = pl.Trainer( - gpus=2, - precision=16, - max_epochs=4, - reload_dataloaders_every_epoch=True, - limit_train_batches=10, - limit_val_batches=10, - limit_test_batches=10, - callbacks=[model_checkpoint], - resume_from_checkpoint=model_checkpoint.last_model_path, - ) - trainer.fit(task, train_dataloader) +if __name__ == '__main__': + run() diff --git a/pl_examples/model_resume.py b/pl_examples/model_resume.py deleted file mode 100644 index f56e9750105d8..0000000000000 --- a/pl_examples/model_resume.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -from torch.utils.data import DataLoader - -import pytorch_lightning as pl -from pl_examples.bug_report_model import ToyTask -from pytorch_lightning.callbacks import ModelCheckpoint - -if __name__ == "__main__": - task = ToyTask() - - dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)] - - train_dataloader = DataLoader(dataset, batch_size=None) - val_dataloader = DataLoader(dataset, batch_size=None) - - model_checkpoint = ModelCheckpoint( - save_last=True, - every_n_val_epochs=1, - ) - - trainer = pl.Trainer( - gpus=2, - precision=16, - max_epochs=4, - reload_dataloaders_every_epoch=True, - limit_train_batches=10, - limit_val_batches=10, - limit_test_batches=10, - callbacks=[model_checkpoint], - resume_from_checkpoint= - "/home/adrian/repositories/pytorch-lightning/lightning_logs/version_82/checkpoints/last.ckpt", - ) - trainer.fit(task, train_dataloader) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 5ccbad575ae15..dc688de65cd34 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -247,7 +247,7 @@ def _load_config(self, config): config = json.load(f) return config - def pre_dispatch(self) -> None: + def pre_dispatch(self): self.init_deepspeed() self.barrier() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c04191a57bfa8..d353c0941d3f6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -428,10 +428,7 @@ def test_model_checkpoint_only_weights(tmpdir): # assert restoring train state fails with pytest.raises(KeyError, match="checkpoint contains only the model"): - trainer.checkpoint_connector.resume_from_checkpoint(new_weights_path) - trainer.checkpoint_connector.resume_start() - trainer.checkpoint_connector.restore_training_state() - trainer.checkpoint_connector.resume_end() + trainer.checkpoint_connector.restore(new_weights_path) def test_model_freeze_unfreeze(): From 0fa9807f855095e7761d3dc196e043804b0646c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 15:14:00 +0200 Subject: [PATCH 07/29] clean up --- .../connectors/checkpoint_connector.py | 20 +------------------ pytorch_lightning/trainer/trainer.py | 3 +-- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 128c5501b79da..244ae8dbb3608 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,9 +14,8 @@
 
 import os
 import re
-from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Optional, Union
 
 import torch
 
@@ -53,13 +52,6 @@ def hpc_resume_path(self) -> Optional[str]:
         if max_version is not None:
             return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
 
-    def resume_from_checkpoint(self, path: Union[str, Path], **kwargs) -> None:
-        """
-        Signals the Trainer to resume from the given path the next time Trainer.fit/validate/test/predict is called.
-        """
-        self.resume_checkpoint_path = path
-        # TODO: decide what to resume
-
     def resume_start(self) -> None:
         """
         Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@@ -101,15 +93,6 @@ def resume_end(self) -> None:
         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
 
-    # TODO: decide if we should use it or not (e.g., in Trainer.fit over self._run())
-    @contextmanager
-    def restore_ctx(self):
-        try:
-            self.resume_start()
-            yield
-        finally:
-            self.resume_end()
-
     def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
         """
         Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file
@@ -387,7 +370,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 lr_schedulers.append(scheduler['scheduler'].state_dict())
             checkpoint['lr_schedulers'] = lr_schedulers
 
-        # dump amp scaling
         self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
 
         # dump hyper-parameters
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 85889e2ddc9ca..ccc620b7c95af 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -458,7 +458,6 @@ def fit(
 
         self.checkpoint_connector.resume_start()
 
-        # with self.checkpoint_connector.restore_ctx():
         self._run(model)
 
         assert self.state.stopped
@@ -735,7 +734,7 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
         self.call_hook("on_before_accelerator_backend_setup", model)
         self.accelerator.connect(model)
         self.accelerator.setup_environment()
-        self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator
+        self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
 
         # restore modules after setup
         self.checkpoint_connector.restore_datamodule()

From f62cd51bbadf44ff286f553b6655eca4c2af9b09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 12 Jun 2021 15:35:13 +0200
Subject: [PATCH 08/29] test hook calls

---
 .../connectors/checkpoint_connector.py        |  5 +++
 .../connectors/test_checkpoint_connector.py   | 41 +++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index f203b28c09048..9581f9e76be20 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -135,6 +135,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
 
         # hook: give user access to checkpoint if needed. 
model.on_load_checkpoint(checkpoint)
 
+        # call hpc specific hook
+        if self.hpc_resume_path is not None:
+            model.on_hpc_load(self._loaded_checkpoint)
+
         # restore model state_dict
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
 
@@ -366,6 +370,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
 
+        # dump hyper-parameters
         # dump hyper-parameters
         if model.hparams:
             if hasattr(model, '_hparams_name'):
diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py
index a8c18e126733e..c2e3d56791cf2 100644
--- a/tests/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/trainer/connectors/test_checkpoint_connector.py
@@ -19,6 +19,47 @@
 from tests.helpers import BoringModel
 
 
+class HPCHookedModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.hpc_save_called = 0
+        self.hpc_load_called = 0
+
+    def on_hpc_save(self, checkpoint):
+        assert "state_dict" in checkpoint
+        self.hpc_save_called += 1
+
+    def on_hpc_load(self, checkpoint):
+        assert "state_dict" in checkpoint
+        self.hpc_load_called += 1
+
+
+def test_hpc_hook_calls(tmpdir):
+    model = HPCHookedModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        checkpoint_callback=False,
+    )
+    trainer.fit(model)
+    connector = trainer.checkpoint_connector
+    connector.hpc_save(tmpdir, logger=trainer.logger)
+    assert model.hpc_save_called == 1
+    assert model.hpc_load_called == 0
+
+    # new training run, restore from hpc checkpoint file automatically
+    assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt", "lightning_logs"}
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        checkpoint_callback=False,
+    )
+    trainer.fit(model)
+    assert model.hpc_save_called == 1
+    assert model.hpc_load_called == 1
+
+
 def test_preloaded_checkpoint_lifecycle(tmpdir):

From 09dd67dbfd90308f8732a637b04a89351337ed89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 12 Jun 2021 15:35:53 +0200
Subject: [PATCH 09/29] space

---
 pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 9581f9e76be20..ed00cc6c9b7ca 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -261,6 +261,7 @@ def restore_lr_schedulers(self) -> None:
     def hpc_load(self, checkpoint_path: str):
         """
         Attempts to restore the full training and model state from a HPC checkpoint file.
+
         .. deprecated::
             `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6.
             Use `CheckpointConnector.restore` instead.
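For downstream code, the migration implied by the deprecation above is a one-line change. A minimal sketch, mirroring how the tests in this series were updated; `trainer` and `checkpoint_path` here are placeholders for an already-fitted Trainer and a path such as one returned by `get_max_ckpt_path_from_folder()`:

    # before: emits the v1.4 deprecation warning, then forwards to restore()
    trainer.checkpoint_connector.hpc_load(checkpoint_path)

    # after: restore() pre-loads the checkpoint, restores the datamodule, model,
    # callback and training states, then releases the cached checkpoint dict
    trainer.checkpoint_connector.restore(checkpoint_path)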
From 1dea5be6c3d1c2efebd116595a754e2f4df335b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 12 Jun 2021 13:37:19 +0000 Subject: [PATCH 10/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ed00cc6c9b7ca..24424f2c19201 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -261,7 +261,7 @@ def restore_lr_schedulers(self) -> None: def hpc_load(self, checkpoint_path: str): """ Attempts to restore the full training and model state from a HPC checkpoint file. - + .. deprecated:: `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. Use `CheckpointConnector.restore` instead. From 1a2c6e4dafa4c18fbbc3f7bb3c85013689619588 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 15:54:26 +0200 Subject: [PATCH 11/29] unused import --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 24424f2c19201..ed3c97dde80cb 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -29,7 +29,6 @@ rank_zero_warn, ) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem -from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS From af256250424e50176bd143e1f043648b9bea918a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 17:30:10 +0200 Subject: [PATCH 12/29] fix info message --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 9e100715726a8..9bdeaf7dd83dd 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -82,7 +82,8 @@ def resume_start(self) -> None: def resume_end(self) -> None: """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. 
""" - rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") + if self.resume_checkpoint_path: + rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") self.resume_checkpoint_path = None self._loaded_checkpoint = dict() From de6e6d9eb57404b69c8cc04306920423cf748a04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 19:32:26 +0200 Subject: [PATCH 13/29] move --- pytorch_lightning/trainer/trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ccc620b7c95af..5d2e09e9fce90 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -456,8 +456,6 @@ def fit( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) - self.checkpoint_connector.resume_start() - self._run(model) assert self.state.stopped @@ -724,6 +722,8 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED # attach model log function to callback self.callback_connector.attach_model_logging_functions(model) + self.checkpoint_connector.resume_start() + # hook self.data_connector.prepare_data(model) self.callback_connector._attach_model_callbacks(model, self) @@ -808,6 +808,9 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED def _pre_dispatch(self): self.accelerator.pre_dispatch(self) + # restore optimizers, etc. + self.checkpoint_connector.restore_training_state() + # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) @@ -815,9 +818,6 @@ def _pre_dispatch(self): self.logger.log_graph(self.lightning_module) self.logger.save() - # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() - def _post_dispatch(self): self.accelerator.post_dispatch(self) self.accelerator.teardown() From d89a98e894d2852f508deab51bb79460abce13d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 19:38:34 +0200 Subject: [PATCH 14/29] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c075a28c8803e..996e4b38af15f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831)) +- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()` but before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) + + ### Deprecated From 777a297c95aa23ceaf2ce785f419788f893f72c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 19:39:46 +0200 Subject: [PATCH 15/29] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 996e4b38af15f..ba72f4eb4932f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -169,7 +169,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831)) -- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()` but before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) +- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) ### Deprecated From ce98239cb84f50b493bb682425724b64deff9b00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 13:11:40 +0200 Subject: [PATCH 16/29] test moving resume end after pre_dispatch --- pytorch_lightning/trainer/trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5d2e09e9fce90..d56905507de7f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -783,6 +783,10 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED # plugin will setup fitting (e.g. ddp will launch child processes) self._pre_dispatch() + # restore optimizers, etc. + self.checkpoint_connector.restore_training_state() + self.checkpoint_connector.resume_end() + # dispatch `start_training` or `start_evaluating` or `start_predicting` self._dispatch() @@ -808,9 +812,6 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED def _pre_dispatch(self): self.accelerator.pre_dispatch(self) - # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() - # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) @@ -861,8 +862,6 @@ def _pre_training_routine(self): if self.is_global_zero and self.weights_summary is not None and not self.testing: ref_model.summarize(mode=self.weights_summary) - self.checkpoint_connector.resume_end() - # on pretrain routine end self.on_pretrain_routine_end() ref_model.on_pretrain_routine_end() From 492856c7984436b3dc3222079804d45abc00e8bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 15:43:41 +0200 Subject: [PATCH 17/29] wip --- .../trainer/connectors/checkpoint_connector.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 87f12b223298e..c0522ad7875b6 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -63,6 +63,9 @@ def resume_start(self) -> None: Raises: FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. """ + if not self.trainer.training: + return + self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path checkpoint_path = self.resume_checkpoint_path if not checkpoint_path: @@ -82,6 +85,9 @@ def resume_start(self) -> None: def resume_end(self) -> None: """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. 
""" + if not self.trainer.training: + return + if self.resume_checkpoint_path: rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") self.resume_checkpoint_path = None From 71c74fe6fe3d03e39acb29138bf69127ade69ed1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 15 Jun 2021 22:29:32 +0200 Subject: [PATCH 18/29] move --- pytorch_lightning/trainer/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d56905507de7f..1cd0e5196536a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -456,6 +456,8 @@ def fit( model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule ) + self.checkpoint_connector.resume_start() + self._run(model) assert self.state.stopped @@ -722,8 +724,6 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED # attach model log function to callback self.callback_connector.attach_model_logging_functions(model) - self.checkpoint_connector.resume_start() - # hook self.data_connector.prepare_data(model) self.callback_connector._attach_model_callbacks(model, self) @@ -785,7 +785,6 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED # restore optimizers, etc. self.checkpoint_connector.restore_training_state() - self.checkpoint_connector.resume_end() # dispatch `start_training` or `start_evaluating` or `start_predicting` self._dispatch() @@ -812,6 +811,8 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED def _pre_dispatch(self): self.accelerator.pre_dispatch(self) + self.checkpoint_connector.resume_end() + # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) From 0f888971060d4c7f40feb19f585b1faab16f33e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 14 Jun 2021 15:43:41 +0200 Subject: [PATCH 19/29] Revert "wip" This reverts commit 492856c7984436b3dc3222079804d45abc00e8bf. --- .../trainer/connectors/checkpoint_connector.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c0522ad7875b6..87f12b223298e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -63,9 +63,6 @@ def resume_start(self) -> None: Raises: FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. """ - if not self.trainer.training: - return - self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path checkpoint_path = self.resume_checkpoint_path if not checkpoint_path: @@ -85,9 +82,6 @@ def resume_start(self) -> None: def resume_end(self) -> None: """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. 
""" - if not self.trainer.training: - return - if self.resume_checkpoint_path: rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") self.resume_checkpoint_path = None From 1480442ad1fa3bcece418f24d197d60813addc31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 02:04:31 +0200 Subject: [PATCH 20/29] move misplaced resume_end() --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1cd0e5196536a..92058f30845a3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -811,8 +811,6 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED def _pre_dispatch(self): self.accelerator.pre_dispatch(self) - self.checkpoint_connector.resume_end() - # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) @@ -850,6 +848,8 @@ def _pre_training_routine(self): # register auto-resubmit when on SLURM self.slurm_connector.register_slurm_signal_handlers() + self.checkpoint_connector.resume_end() + # -------------------------- # Pre-train # -------------------------- From a987f1d83b501d6c852338ced28aa200d41b2d45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 02:47:28 +0200 Subject: [PATCH 21/29] add guard to restore_datamodule --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 87f12b223298e..a426c3d736982 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -124,6 +124,9 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None: def restore_datamodule(self) -> None: """ Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. 
""" + if not self._loaded_checkpoint: + return + datamodule = self.trainer.datamodule if datamodule is not None: datamodule.on_load_checkpoint(self._loaded_checkpoint) From c8ef69316a9621b963f7b2702b90a3fe9076464b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 09:47:28 +0200 Subject: [PATCH 22/29] rm duplicate comment --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index a426c3d736982..b599caf91e20d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -363,7 +363,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: self.trainer.precision_plugin.on_save_checkpoint(checkpoint) - # dump hyper-parameters # dump hyper-parameters if model.hparams: if hasattr(model, '_hparams_name'): From 6a38c9bfdde672cd5b203939ee98991c5ef5c25a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 14:19:37 +0200 Subject: [PATCH 23/29] add hook test --- tests/models/test_hooks.py | 48 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 6413ca8c930bd..c22bcd6e9385d 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -500,6 +500,54 @@ def test_trainer_model_hook_system_test(tmpdir): assert called == expected +def test_trainer_model_hook_system_fit_resume(tmpdir): + # initial training to get a checkpoint + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + limit_val_batches=0, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + trainer.fit(model) + best_model_path = trainer.checkpoint_callback.best_model_path + # resume from checkpoint with HookedModel + + called = [] + model = HookedModel(called) + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + limit_val_batches=0, + progress_bar_refresh_rate=0, + weights_summary=None, + resume_from_checkpoint=best_model_path, + ) + assert called == [] + trainer.fit(model) + expected = [ + 'prepare_data', + 'configure_callbacks', + 'setup', + 'on_load_checkpoint', + 'configure_sharded_model', + 'configure_optimizers', + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'train', + 'on_train_dataloader', + 'train_dataloader', + # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches` + 'on_val_dataloader', + 'val_dataloader', + 'on_fit_end', + 'teardown', + ] + assert called == expected + + def test_hooks_with_different_argument_names(tmpdir): """ Test that argument names can be anything in the hooks From d208b5c41b94652d2d8f47e689a98b2f5c3ed65a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 14:22:44 +0200 Subject: [PATCH 24/29] comment --- tests/models/test_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index c22bcd6e9385d..353a3c91f7ece 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -512,8 +512,8 @@ def test_trainer_model_hook_system_fit_resume(tmpdir): ) trainer.fit(model) best_model_path = trainer.checkpoint_callback.best_model_path + # resume from checkpoint with HookedModel - called = [] model = 
HookedModel(called) trainer = Trainer( From 25918e78ab5c44d31a94434b9a9901a432ff329d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Jun 2021 12:24:06 +0000 Subject: [PATCH 25/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/models/test_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 353a3c91f7ece..5d64afa9affda 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -512,7 +512,7 @@ def test_trainer_model_hook_system_fit_resume(tmpdir): ) trainer.fit(model) best_model_path = trainer.checkpoint_callback.best_model_path - + # resume from checkpoint with HookedModel called = [] model = HookedModel(called) From 3d9539da094681536663171c7c157da24f41ed57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 14:24:16 +0200 Subject: [PATCH 26/29] blank --- tests/models/test_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 353a3c91f7ece..5d64afa9affda 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -512,7 +512,7 @@ def test_trainer_model_hook_system_fit_resume(tmpdir): ) trainer.fit(model) best_model_path = trainer.checkpoint_callback.best_model_path - + # resume from checkpoint with HookedModel called = [] model = HookedModel(called) From b45c3355709de8206107c4d5c9bbf99fde1f11de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 15:43:17 +0200 Subject: [PATCH 27/29] merge tests --- tests/models/test_hooks.py | 70 ++++++++++---------------------------- 1 file changed, 18 insertions(+), 52 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5d64afa9affda..5c6dde1c5829b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -377,17 +377,30 @@ def test_trainer_model_hook_system_fit(tmpdir): assert called == expected -def test_trainer_model_hook_system_fit_no_val(tmpdir): +def test_trainer_model_hook_system_fit_with_resume(tmpdir): + # initial training to get a checkpoint + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + limit_val_batches=0, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + trainer.fit(model) + best_model_path = trainer.checkpoint_callback.best_model_path + + # resume from checkpoint with HookedModel called = [] model = HookedModel(called) train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, + max_steps=2, limit_val_batches=0, - limit_train_batches=train_batches, progress_bar_refresh_rate=0, weights_summary=None, + resume_from_checkpoint=best_model_path, ) assert called == [] trainer.fit(model) @@ -395,6 +408,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): 'prepare_data', 'configure_callbacks', 'setup', + 'on_load_checkpoint', 'configure_sharded_model', 'configure_optimizers', 'on_fit_start', @@ -409,7 +423,7 @@ def test_trainer_model_hook_system_fit_no_val(tmpdir): 'on_train_start', 'on_epoch_start', 'on_train_epoch_start', - *(HookedModel._train_batch() * train_batches), + *(HookedModel._train_batch() * (train_batches - 1)), 'training_epoch_end', 'on_train_epoch_end', 'on_epoch_end', @@ -500,54 +514,6 @@ def test_trainer_model_hook_system_test(tmpdir): assert called == expected -def 
test_trainer_model_hook_system_fit_resume(tmpdir):
-    # initial training to get a checkpoint
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=1,
-        limit_val_batches=0,
-        progress_bar_refresh_rate=0,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-    best_model_path = trainer.checkpoint_callback.best_model_path
-
-    # resume from checkpoint with HookedModel
-    called = []
-    model = HookedModel(called)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=1,
-        limit_val_batches=0,
-        progress_bar_refresh_rate=0,
-        weights_summary=None,
-        resume_from_checkpoint=best_model_path,
-    )
-    assert called == []
-    trainer.fit(model)
-    expected = [
-        'prepare_data',
-        'configure_callbacks',
-        'setup',
-        'on_load_checkpoint',
-        'configure_sharded_model',
-        'configure_optimizers',
-        'on_fit_start',
-        'on_pretrain_routine_start',
-        'on_pretrain_routine_end',
-        'train',
-        'on_train_dataloader',
-        'train_dataloader',
-        # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
-        'on_val_dataloader',
-        'val_dataloader',
-        'on_fit_end',
-        'teardown',
-    ]
-    assert called == expected
-
-
 def test_hooks_with_different_argument_names(tmpdir):
     """
     Test that argument names can be anything in the hooks

From 7e38debd1a3cb9fe84b11a9805e1eee271e33a89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 16 Jun 2021 15:50:28 +0200
Subject: [PATCH 28/29] clarify how many batches need to run

---
 tests/models/test_hooks.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 58f03721c82b2..c019038794dc5 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -437,7 +437,8 @@ def test_trainer_model_hook_system_fit_with_resume(tmpdir):
     train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_steps=2,
+        # already performed 1 step, now resuming to do an additional 2
+        max_steps=(1 + train_batches),
         limit_val_batches=0,
         progress_bar_refresh_rate=0,
         weights_summary=None,
@@ -464,7 +465,7 @@ def test_trainer_model_hook_system_fit_with_resume(tmpdir):
         'on_train_start',
         'on_epoch_start',
         'on_train_epoch_start',
-        *(HookedModel._train_batch() * (train_batches - 1)),
+        *(HookedModel._train_batch() * train_batches),
         'training_epoch_end',
         'on_train_epoch_end',
         'on_epoch_end',

From 02c3d54f3ba01cf4984fe2c4b6b443460ad556f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 17 Jun 2021 09:49:45 +0200
Subject: [PATCH 29/29] Update tests/models/test_hooks.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí

---
 tests/models/test_hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index c019038794dc5..9938f756bf3fa 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -418,7 +418,7 @@ def test_trainer_model_hook_system_fit(tmpdir):
     assert called == expected
 
 
-def test_trainer_model_hook_system_fit_with_resume(tmpdir):
+def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
     # initial training to get a checkpoint
     model = BoringModel()
     trainer = Trainer(