From f2f685872fcbb41c9aa6211ebb06caea01207ece Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Jul 2021 14:23:16 +0200 Subject: [PATCH 01/29] wip --- pytorch_lightning/loops/base.py | 42 +++++- .../connectors/checkpoint_connector.py | 34 ++++- .../connectors/logger_connector/result.py | 10 ++ tests/checkpointing/test_model_checkpoint.py | 130 ++++++++++++++++++ 4 files changed, 210 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 8219e2ad02fc8..5ac8ed78d3289 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,8 +16,10 @@ from typing import Any, Dict, Optional from deprecate import void +from torchmetrics import Metric import pytorch_lightning as pl +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import BaseProgress, Progress from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -177,21 +179,55 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[prefix + k] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, prefix + k + ".") + elif isinstance(v, ResultCollection): + # sync / unsync metrics + v.sync() + destination[prefix + k] = v.state_dict() + v.unsync() return destination - def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None: + def load_state_dict( + self, + state_dict: Dict, + prefix: str = "", + restart_progress: bool = True, + metrics: Optional[Dict[str, Metric]] = None, + ) -> None: """Loads the state of this loop and all its children.""" - self._load_from_state_dict(state_dict.copy(), prefix, restart_progress) + self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics) for k, v in self.__dict__.items(): if isinstance(v, Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress) - def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None: + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None + ) -> None: for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) if restart_progress: apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) + + elif isinstance(v, ResultCollection): + if isinstance(self.trainer, pl.Trainer) and getattr(self.trainer, "lightning_module", None) is not None: + metric_attributes = { + name: module + for name, module in self.trainer.lightning_module.named_modules() + if isinstance(module, Metric) + } + if metrics: + metric_attributes.update(metrics) + + # re-attach metrics references + v.load_state_dict( + state_dict[prefix + k], + metrics=metric_attributes, + sync_fn=self.trainer.training_type_plugin.reduce, + ) + + if not self.trainer.is_global_zero: + v.reset(metrics=False) + self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 611946fd53dae..c8779e8db893d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -15,9 +15,10 @@ import os import re from pathlib import Path -from typing 
import Optional, Union +from typing import Any, Dict, Optional, Union import torch +from torchmetrics import Metric import pytorch_lightning as pl from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn @@ -35,6 +36,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = self.trainer = trainer self.resume_checkpoint_path = resume_from_checkpoint self._loaded_checkpoint = {} + self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: @@ -141,6 +143,12 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + # reset state on non-rank 0 + if not self.trainer.is_global_zero: + for module in self.trainer.lightning_module.modules(): + if isinstance(module, Metric): + module.reset() + def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None: """Restore only the model weights.""" checkpoint = self._loaded_checkpoint @@ -338,7 +346,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: "epoch": current_epoch, "global_step": global_step, "pytorch-lightning_version": pl.__version__, - "state_dict": self.trainer.accelerator.lightning_module_state_dict(), + "state_dict": self._get_lightning_module_state_dict(), } if _fault_tolerant_enabled(): checkpoint["loops"] = self._get_loops_state_dict() @@ -440,7 +448,27 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: _checkpoint = self.dump_checkpoint(weights_only) self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) - def _get_loops_state_dict(self): + def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: + if _fault_tolerant_enabled(): + for _, module in self.trainer.lightning_module.named_modules(): + if isinstance(module, Metric): + if self._persistent_metrics: + module.persistent(True) + if not module._is_synced: + module.sync() + self._persistent_metrics = True + + state_dict = self.trainer.accelerator.lightning_module_state_dict() + + if _fault_tolerant_enabled(): + for _, module in self.trainer.lightning_module.named_modules(): + if isinstance(module, Metric): + if module._is_synced: + module.unsync() + + return state_dict + + def _get_loops_state_dict(self) -> Dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index aab86976fe76f..afaa26d0e85ba 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -604,6 +604,16 @@ def cpu(self) -> "ResultCollection": """Move all data to CPU.""" return self.to(device="cpu") + def sync(self) -> None: + for result_metric in self.result_metrics: + if result_metric.is_tensor and not result_metric._is_synced: + result_metric.sync() + + def unsync(self) -> None: + for result_metric in self.result_metrics: + if result_metric.is_tensor and result_metric._is_synced: + result_metric.unsync() + def __str__(self) -> str: return f"{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})" diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0906ed3820705..f46d70fdf713b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ 
b/tests/checkpointing/test_model_checkpoint.py @@ -18,6 +18,7 @@ import re import time from argparse import Namespace +from contextlib import suppress from datetime import timedelta from logging import INFO from pathlib import Path @@ -31,12 +32,14 @@ import yaml from omegaconf import Container, OmegaConf from torch import optim +from torchmetrics import Metric import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -1220,3 +1223,130 @@ def test_trainer_checkpoint_callback_bool(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir) with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"): Trainer(checkpoint_callback=mc) + + +class DummyMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, increment): + self.sum += increment + self.count += 1 + + def compute(self): + return self.sum // self.count + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" + + +def result_collection_reload(trainer_kwargs): + num_processes = trainer_kwargs.get("gpus", 1) + + class CustomException(Exception): + pass + + class ExtendedBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.has_reloaded = False + self.breaking_batch_idx = 3 + self.has_validated_sum = False + self.dummy_metric = DummyMetric() + self.dummy_metric_dynamic = DummyMetric() + + def training_step(self, batch, batch_idx): + assert len(batch) == 1 + if self.has_reloaded: + if batch_idx >= self.breaking_batch_idx: + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking_metric"].value + value_2 = self.trainer.fit_loop._results["training_step.tracking"].value + shift = 0 + if num_processes == 2: + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value == value_2 + else: + if self.trainer.current_epoch == 2: + return + if batch_idx == self.breaking_batch_idx: + # simulate failure mid epoch + raise CustomException + + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking"].value + assert value == sum(range(batch_idx + 1)) + + value = self.trainer.fit_loop._results["training_step.tracking_2"] + assert value == sum(range(batch_idx + 1)) + + return super().training_step(batch, batch_idx) + + def on_epoch_end(self) -> None: + if self.trainer.current_epoch: + total = sum(range(5)) * num_processes + metrics = self.trainer.fit_loop._results.metrics(on_step=False) 
+ assert self.trainer.fit_loop._results["training_step.tracking"].value == total + assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 + assert self.trainer.fit_loop._results["training_step.tracking_2"].value == total + assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 + self.has_validated_sum = True + + model = ExtendedBoringModel() + + trainer = Trainer(**trainer_kwargs) + + with suppress(CustomException): + trainer.fit(model) + + checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], "ckpt.pt")) + trainer.save_checkpoint(checkpoint_path) + + trainer.accelerator.barrier() + + if trainer.is_global_zero: + checkpoint = torch.load(checkpoint_path) + assert checkpoint["state_dict"]["dummy_metric.sum"] == 3 * num_processes + + trainer_kwargs["resume_from_checkpoint"] = checkpoint_path + trainer_kwargs["max_epochs"] = 2 + + trainer = Trainer(**trainer_kwargs) + model.has_reloaded = True + trainer.fit(model) + assert model.has_validated_sum + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload(tmpdir): + result_collection_reload( + {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + ) + + +@RunIf(min_gpus=2, special=True) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload_2_gpus(tmpdir): + result_collection_reload( + { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 2, + } + ) From 4307a055a7d9064e5532b645c6c8c749af7d097b Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Jul 2021 16:54:18 +0200 Subject: [PATCH 02/29] resolve some issues --- tests/checkpointing/test_model_checkpoint.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f46d70fdf713b..ddd6ed9f45df8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1255,12 +1255,21 @@ def __init__(self): self.breaking_batch_idx = 3 self.has_validated_sum = False self.dummy_metric = DummyMetric() - self.dummy_metric_dynamic = DummyMetric() def training_step(self, batch, batch_idx): assert len(batch) == 1 if self.has_reloaded: - if batch_idx >= self.breaking_batch_idx: + # hack as the state is being reset on epoch end + if batch_idx == 3 and self.current_epoch == 0: + self.metric_state = self.trainer.fit_loop._results[ + "training_step.tracking_metric" + ].value.state_dict() + if batch_idx == 0 and self.current_epoch == 1: + self.trainer.fit_loop._results["training_step.tracking_metric"].value.load_state_dict( + self.metric_state + ) + + if batch_idx >= self.breaking_batch_idx and self.current_epoch == 1: self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) From e63c5601928c4a5bd9fd74339efa0517e9259882 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 30 Jul 2021 07:15:06 -0400 Subject: [PATCH 03/29] add ResultCollection --- .../plugins/training_type/ddp.py | 2 +- .../connectors/checkpoint_connector.py | 22 ++++------ .../connectors/logger_connector/result.py | 3 +- pytorch_lightning/trainer/trainer.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 43 +++++++------------ 5 files changed, 28 insertions(+), 44 
deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a44384e18edcd..86afeefd4399f 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -411,7 +411,7 @@ def _share_information_to_prevent_deadlock(self): self._share_pids() # remove `PL_DDP_SYNC_TMPDIR` from os.environ - self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None) + self._sync_dir = os.getenv("PL_DDP_SYNC_TMPDIR") def _share_pids(self): """ diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c8779e8db893d..aa50491c3ae5d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -36,7 +36,6 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = self.trainer = trainer self.resume_checkpoint_path = resume_from_checkpoint self._loaded_checkpoint = {} - self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: @@ -220,7 +219,7 @@ def restore_loops(self) -> None: " consider using an end of epoch checkpoint." ) - state_dict = self._loaded_checkpoint.get("loops") + state_dict = self._loaded_checkpoint.pop("loops", None) if state_dict: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) @@ -450,21 +449,18 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: if _fault_tolerant_enabled(): - for _, module in self.trainer.lightning_module.named_modules(): - if isinstance(module, Metric): - if self._persistent_metrics: - module.persistent(True) - if not module._is_synced: - module.sync() - self._persistent_metrics = True + metrics = [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] + for metric in metrics: + metric.persistent(True) + if not metric._is_synced: + metric.sync() state_dict = self.trainer.accelerator.lightning_module_state_dict() if _fault_tolerant_enabled(): - for _, module in self.trainer.lightning_module.named_modules(): - if isinstance(module, Metric): - if module._is_synced: - module.unsync() + for metric in metrics: + if metric._is_synced: + metric.unsync() return state_dict diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index afaa26d0e85ba..e71145aa29332 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -251,13 +251,14 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: - skip = ["update", "compute", "_update_signature"] + skip = ["update", "compute", "_update_signature", "_cache"] if not self.is_tensor and drop_value: # Avoid serializing ResultMetrics which are passed Metrics skip.append("value") d = {k: v for k, v in self.__dict__.items() if k not in skip} d["meta"] = d["meta"].__getstate__() d["_class"] = self.__class__.__name__ + d["_is_synced"] = False # don't consider the state as synced on reload return d def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 
6bb9263620245..8a6c0a537a6c4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1310,7 +1310,7 @@ def _log_device_info(self) -> None: ) def _on_expection(self): - if not self.is_global_zero or not _fault_tolerant_enabled(): + if not _fault_tolerant_enabled(): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ddd6ed9f45df8..b8092f591abb3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1259,33 +1259,20 @@ def __init__(self): def training_step(self, batch, batch_idx): assert len(batch) == 1 if self.has_reloaded: - # hack as the state is being reset on epoch end - if batch_idx == 3 and self.current_epoch == 0: - self.metric_state = self.trainer.fit_loop._results[ - "training_step.tracking_metric" - ].value.state_dict() - if batch_idx == 0 and self.current_epoch == 1: - self.trainer.fit_loop._results["training_step.tracking_metric"].value.load_state_dict( - self.metric_state - ) - - if batch_idx >= self.breaking_batch_idx and self.current_epoch == 1: - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) - - self.dummy_metric(batch_idx) - self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - - value = self.trainer.fit_loop._results["training_step.tracking_metric"].value - value_2 = self.trainer.fit_loop._results["training_step.tracking"].value - shift = 0 - if num_processes == 2: - shift = 3 if self.trainer.is_global_zero else -3 - expected = sum(range(batch_idx + 1)) + shift - assert expected == value == value_2 + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking_metric"].value + value_2 = self.trainer.fit_loop._results["training_step.tracking"].value + shift = 0 + if num_processes == 2: + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value == value_2 else: - if self.trainer.current_epoch == 2: - return if batch_idx == self.breaking_batch_idx: # simulate failure mid epoch raise CustomException @@ -1305,7 +1292,7 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) def on_epoch_end(self) -> None: - if self.trainer.current_epoch: + if self.has_reloaded: total = sum(range(5)) * num_processes metrics = self.trainer.fit_loop._results.metrics(on_step=False) assert self.trainer.fit_loop._results["training_step.tracking"].value == total @@ -1342,7 +1329,7 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_result_collection_reload(tmpdir): result_collection_reload( - {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, "accelerator": "ddp", "gpus": 1} ) From bd916659daf0219f0e2ef96bee51120d68ddfd23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jul 2021 11:17:14 +0000 Subject: [PATCH 04/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/result.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e71145aa29332..af2efdfad7449 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -258,7 +258,7 @@ def __getstate__(self, drop_value: bool = False) -> dict: d = {k: v for k, v in self.__dict__.items() if k not in skip} d["meta"] = d["meta"].__getstate__() d["_class"] = self.__class__.__name__ - d["_is_synced"] = False # don't consider the state as synced on reload + d["_is_synced"] = False # don't consider the state as synced on reload return d def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b8092f591abb3..d98c4dffb146c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1329,7 +1329,14 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_result_collection_reload(tmpdir): result_collection_reload( - {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, "accelerator": "ddp", "gpus": 1} + { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 1, + } ) From ac9e9f1185ed0589a550c9b5c160b9f2afa49a3e Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 30 Jul 2021 07:17:16 -0400 Subject: [PATCH 05/29] add comments --- pytorch_lightning/plugins/training_type/ddp.py | 1 + pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 86afeefd4399f..f03a439532b18 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -411,6 +411,7 @@ def _share_information_to_prevent_deadlock(self): self._share_pids() # remove `PL_DDP_SYNC_TMPDIR` from os.environ + # FIXME: Add better support for deadlock detection. Changed TMPDIR at on every trainer.{call_fn}. self._sync_dir = os.getenv("PL_DDP_SYNC_TMPDIR") def _share_pids(self): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index aa50491c3ae5d..1e0ab08344c1b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -142,7 +142,7 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) - # reset state on non-rank 0 + # reset metrics states on non-rank 0 as the states have been synced on-saving. 
if not self.trainer.is_global_zero: for module in self.trainer.lightning_module.modules(): if isinstance(module, Metric): From 46078e91ecfa7bf9743de918a23d830bca26dbe9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Jul 2021 13:28:42 +0200 Subject: [PATCH 06/29] update changelog --- CHANGELOG.md | 4 +++- tests/deprecated_api/test_remove_1-5.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a2806602291e7..16815f1416478 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) +- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641)) + + - @@ -28,7 +31,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) - - Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 8dbe17e7a0a16..d6adb71de0dac 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -19,7 +19,6 @@ import pytest import torch -from torch import optim from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint From f1d50d221b1ae8b5d5ea43c8f780b55bfb127be4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Jul 2021 14:23:16 +0200 Subject: [PATCH 07/29] wip --- pytorch_lightning/loops/base.py | 42 +++++- .../connectors/checkpoint_connector.py | 34 ++++- .../connectors/logger_connector/result.py | 10 ++ tests/checkpointing/test_model_checkpoint.py | 130 ++++++++++++++++++ 4 files changed, 210 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 8219e2ad02fc8..5ac8ed78d3289 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,8 +16,10 @@ from typing import Any, Dict, Optional from deprecate import void +from torchmetrics import Metric import pytorch_lightning as pl +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import BaseProgress, Progress from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -177,21 +179,55 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[prefix + k] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, prefix + k + ".") + elif isinstance(v, ResultCollection): + # sync / unsync metrics + v.sync() + destination[prefix + k] = v.state_dict() + v.unsync() return destination - def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None: + def load_state_dict( + self, + state_dict: Dict, + prefix: str = "", + restart_progress: bool = True, + metrics: Optional[Dict[str, Metric]] = None, + ) -> None: """Loads the state of this loop and all its children.""" - 
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress) + self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics) for k, v in self.__dict__.items(): if isinstance(v, Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress) - def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None: + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None + ) -> None: for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) if restart_progress: apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) + + elif isinstance(v, ResultCollection): + if isinstance(self.trainer, pl.Trainer) and getattr(self.trainer, "lightning_module", None) is not None: + metric_attributes = { + name: module + for name, module in self.trainer.lightning_module.named_modules() + if isinstance(module, Metric) + } + if metrics: + metric_attributes.update(metrics) + + # re-attach metrics references + v.load_state_dict( + state_dict[prefix + k], + metrics=metric_attributes, + sync_fn=self.trainer.training_type_plugin.reduce, + ) + + if not self.trainer.is_global_zero: + v.reset(metrics=False) + self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 611946fd53dae..c8779e8db893d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -15,9 +15,10 @@ import os import re from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch +from torchmetrics import Metric import pytorch_lightning as pl from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn @@ -35,6 +36,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = self.trainer = trainer self.resume_checkpoint_path = resume_from_checkpoint self._loaded_checkpoint = {} + self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: @@ -141,6 +143,12 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + # reset state on non-rank 0 + if not self.trainer.is_global_zero: + for module in self.trainer.lightning_module.modules(): + if isinstance(module, Metric): + module.reset() + def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None: """Restore only the model weights.""" checkpoint = self._loaded_checkpoint @@ -338,7 +346,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: "epoch": current_epoch, "global_step": global_step, "pytorch-lightning_version": pl.__version__, - "state_dict": self.trainer.accelerator.lightning_module_state_dict(), + "state_dict": self._get_lightning_module_state_dict(), } if _fault_tolerant_enabled(): checkpoint["loops"] = self._get_loops_state_dict() @@ -440,7 +448,27 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: _checkpoint = self.dump_checkpoint(weights_only) self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) - def _get_loops_state_dict(self): + def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: + if 
_fault_tolerant_enabled(): + for _, module in self.trainer.lightning_module.named_modules(): + if isinstance(module, Metric): + if self._persistent_metrics: + module.persistent(True) + if not module._is_synced: + module.sync() + self._persistent_metrics = True + + state_dict = self.trainer.accelerator.lightning_module_state_dict() + + if _fault_tolerant_enabled(): + for _, module in self.trainer.lightning_module.named_modules(): + if isinstance(module, Metric): + if module._is_synced: + module.unsync() + + return state_dict + + def _get_loops_state_dict(self) -> Dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 44774105cdb49..946c9f9eae59e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -604,6 +604,16 @@ def cpu(self) -> "ResultCollection": """Move all data to CPU.""" return self.to(device="cpu") + def sync(self) -> None: + for result_metric in self.result_metrics: + if result_metric.is_tensor and not result_metric._is_synced: + result_metric.sync() + + def unsync(self) -> None: + for result_metric in self.result_metrics: + if result_metric.is_tensor and result_metric._is_synced: + result_metric.unsync() + def __str__(self) -> str: # sample output: `ResultCollection(minimize=1.23, {})` minimize = f"minimize={self.minimize}, " if self.minimize is not None else "" diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0906ed3820705..f46d70fdf713b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -18,6 +18,7 @@ import re import time from argparse import Namespace +from contextlib import suppress from datetime import timedelta from logging import INFO from pathlib import Path @@ -31,12 +32,14 @@ import yaml from omegaconf import Container, OmegaConf from torch import optim +from torchmetrics import Metric import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -1220,3 +1223,130 @@ def test_trainer_checkpoint_callback_bool(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir) with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"): Trainer(checkpoint_callback=mc) + + +class DummyMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, increment): + self.sum += increment + self.count += 1 + + def compute(self): + return self.sum // self.count + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" + + +def result_collection_reload(trainer_kwargs): + num_processes = trainer_kwargs.get("gpus", 1) + + class CustomException(Exception): + pass + + class 
ExtendedBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.has_reloaded = False + self.breaking_batch_idx = 3 + self.has_validated_sum = False + self.dummy_metric = DummyMetric() + self.dummy_metric_dynamic = DummyMetric() + + def training_step(self, batch, batch_idx): + assert len(batch) == 1 + if self.has_reloaded: + if batch_idx >= self.breaking_batch_idx: + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking_metric"].value + value_2 = self.trainer.fit_loop._results["training_step.tracking"].value + shift = 0 + if num_processes == 2: + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value == value_2 + else: + if self.trainer.current_epoch == 2: + return + if batch_idx == self.breaking_batch_idx: + # simulate failure mid epoch + raise CustomException + + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking"].value + assert value == sum(range(batch_idx + 1)) + + value = self.trainer.fit_loop._results["training_step.tracking_2"] + assert value == sum(range(batch_idx + 1)) + + return super().training_step(batch, batch_idx) + + def on_epoch_end(self) -> None: + if self.trainer.current_epoch: + total = sum(range(5)) * num_processes + metrics = self.trainer.fit_loop._results.metrics(on_step=False) + assert self.trainer.fit_loop._results["training_step.tracking"].value == total + assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 + assert self.trainer.fit_loop._results["training_step.tracking_2"].value == total + assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 + self.has_validated_sum = True + + model = ExtendedBoringModel() + + trainer = Trainer(**trainer_kwargs) + + with suppress(CustomException): + trainer.fit(model) + + checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], "ckpt.pt")) + trainer.save_checkpoint(checkpoint_path) + + trainer.accelerator.barrier() + + if trainer.is_global_zero: + checkpoint = torch.load(checkpoint_path) + assert checkpoint["state_dict"]["dummy_metric.sum"] == 3 * num_processes + + trainer_kwargs["resume_from_checkpoint"] = checkpoint_path + trainer_kwargs["max_epochs"] = 2 + + trainer = Trainer(**trainer_kwargs) + model.has_reloaded = True + trainer.fit(model) + assert model.has_validated_sum + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload(tmpdir): + result_collection_reload( + {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + ) + + +@RunIf(min_gpus=2, special=True) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload_2_gpus(tmpdir): + result_collection_reload( + { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 2, + } + ) From aeaeee62b16417d7931adb626034de22e8180e9b Mon Sep 17 00:00:00 
2001 From: tchaton Date: Thu, 29 Jul 2021 16:54:18 +0200 Subject: [PATCH 08/29] resolve some issues --- tests/checkpointing/test_model_checkpoint.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f46d70fdf713b..ddd6ed9f45df8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1255,12 +1255,21 @@ def __init__(self): self.breaking_batch_idx = 3 self.has_validated_sum = False self.dummy_metric = DummyMetric() - self.dummy_metric_dynamic = DummyMetric() def training_step(self, batch, batch_idx): assert len(batch) == 1 if self.has_reloaded: - if batch_idx >= self.breaking_batch_idx: + # hack as the state is being reset on epoch end + if batch_idx == 3 and self.current_epoch == 0: + self.metric_state = self.trainer.fit_loop._results[ + "training_step.tracking_metric" + ].value.state_dict() + if batch_idx == 0 and self.current_epoch == 1: + self.trainer.fit_loop._results["training_step.tracking_metric"].value.load_state_dict( + self.metric_state + ) + + if batch_idx >= self.breaking_batch_idx and self.current_epoch == 1: self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) From a9368e95e3a68edcbefe08b1a984f68552718194 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 30 Jul 2021 07:15:06 -0400 Subject: [PATCH 09/29] add ResultCollection --- .../plugins/training_type/ddp.py | 2 +- .../connectors/checkpoint_connector.py | 22 ++++------ .../connectors/logger_connector/result.py | 3 +- pytorch_lightning/trainer/trainer.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 43 +++++++------------ 5 files changed, 28 insertions(+), 44 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a44384e18edcd..86afeefd4399f 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -411,7 +411,7 @@ def _share_information_to_prevent_deadlock(self): self._share_pids() # remove `PL_DDP_SYNC_TMPDIR` from os.environ - self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None) + self._sync_dir = os.getenv("PL_DDP_SYNC_TMPDIR") def _share_pids(self): """ diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c8779e8db893d..aa50491c3ae5d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -36,7 +36,6 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = self.trainer = trainer self.resume_checkpoint_path = resume_from_checkpoint self._loaded_checkpoint = {} - self._persistent_metrics = False @property def hpc_resume_path(self) -> Optional[str]: @@ -220,7 +219,7 @@ def restore_loops(self) -> None: " consider using an end of epoch checkpoint." 
) - state_dict = self._loaded_checkpoint.get("loops") + state_dict = self._loaded_checkpoint.pop("loops", None) if state_dict: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) @@ -450,21 +449,18 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: if _fault_tolerant_enabled(): - for _, module in self.trainer.lightning_module.named_modules(): - if isinstance(module, Metric): - if self._persistent_metrics: - module.persistent(True) - if not module._is_synced: - module.sync() - self._persistent_metrics = True + metrics = [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] + for metric in metrics: + metric.persistent(True) + if not metric._is_synced: + metric.sync() state_dict = self.trainer.accelerator.lightning_module_state_dict() if _fault_tolerant_enabled(): - for _, module in self.trainer.lightning_module.named_modules(): - if isinstance(module, Metric): - if module._is_synced: - module.unsync() + for metric in metrics: + if metric._is_synced: + metric.unsync() return state_dict diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 946c9f9eae59e..ec7401168a97b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -251,13 +251,14 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: - skip = ["update", "compute", "_update_signature"] + skip = ["update", "compute", "_update_signature", "_cache"] if not self.is_tensor and drop_value: # Avoid serializing ResultMetrics which are passed Metrics skip.append("value") d = {k: v for k, v in self.__dict__.items() if k not in skip} d["meta"] = d["meta"].__getstate__() d["_class"] = self.__class__.__name__ + d["_is_synced"] = False # don't consider the state as synced on reload return d def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e3a52a09d1bc8..8aa0237d05c6f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1323,7 +1323,7 @@ def _log_device_info(self) -> None: ) def _on_expection(self): - if not self.is_global_zero or not _fault_tolerant_enabled(): + if not _fault_tolerant_enabled(): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. 
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ddd6ed9f45df8..b8092f591abb3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1259,33 +1259,20 @@ def __init__(self): def training_step(self, batch, batch_idx): assert len(batch) == 1 if self.has_reloaded: - # hack as the state is being reset on epoch end - if batch_idx == 3 and self.current_epoch == 0: - self.metric_state = self.trainer.fit_loop._results[ - "training_step.tracking_metric" - ].value.state_dict() - if batch_idx == 0 and self.current_epoch == 1: - self.trainer.fit_loop._results["training_step.tracking_metric"].value.load_state_dict( - self.metric_state - ) - - if batch_idx >= self.breaking_batch_idx and self.current_epoch == 1: - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) - - self.dummy_metric(batch_idx) - self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - - value = self.trainer.fit_loop._results["training_step.tracking_metric"].value - value_2 = self.trainer.fit_loop._results["training_step.tracking"].value - shift = 0 - if num_processes == 2: - shift = 3 if self.trainer.is_global_zero else -3 - expected = sum(range(batch_idx + 1)) + shift - assert expected == value == value_2 + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking_metric"].value + value_2 = self.trainer.fit_loop._results["training_step.tracking"].value + shift = 0 + if num_processes == 2: + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value == value_2 else: - if self.trainer.current_epoch == 2: - return if batch_idx == self.breaking_batch_idx: # simulate failure mid epoch raise CustomException @@ -1305,7 +1292,7 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) def on_epoch_end(self) -> None: - if self.trainer.current_epoch: + if self.has_reloaded: total = sum(range(5)) * num_processes metrics = self.trainer.fit_loop._results.metrics(on_step=False) assert self.trainer.fit_loop._results["training_step.tracking"].value == total @@ -1342,7 +1329,7 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_result_collection_reload(tmpdir): result_collection_reload( - {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, "accelerator": "ddp", "gpus": 1} ) From c174917691cf413142e71f5e5812aeddd7e6906f Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 30 Jul 2021 07:17:16 -0400 Subject: [PATCH 10/29] add comments --- pytorch_lightning/plugins/training_type/ddp.py | 1 + pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 86afeefd4399f..f03a439532b18 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py 
+++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -411,6 +411,7 @@ def _share_information_to_prevent_deadlock(self): self._share_pids() # remove `PL_DDP_SYNC_TMPDIR` from os.environ + # FIXME: Add better support for deadlock detection. Changed TMPDIR at on every trainer.{call_fn}. self._sync_dir = os.getenv("PL_DDP_SYNC_TMPDIR") def _share_pids(self): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index aa50491c3ae5d..1e0ab08344c1b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -142,7 +142,7 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) - # reset state on non-rank 0 + # reset metrics states on non-rank 0 as the states have been synced on-saving. if not self.trainer.is_global_zero: for module in self.trainer.lightning_module.modules(): if isinstance(module, Metric): From 3b7370a7969ab8e801a1e5a2afc24120689ac698 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jul 2021 11:17:14 +0000 Subject: [PATCH 11/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../trainer/connectors/logger_connector/result.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ec7401168a97b..d02d7d683a986 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -258,7 +258,7 @@ def __getstate__(self, drop_value: bool = False) -> dict: d = {k: v for k, v in self.__dict__.items() if k not in skip} d["meta"] = d["meta"].__getstate__() d["_class"] = self.__class__.__name__ - d["_is_synced"] = False # don't consider the state as synced on reload + d["_is_synced"] = False # don't consider the state as synced on reload return d def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b8092f591abb3..d98c4dffb146c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1329,7 +1329,14 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_result_collection_reload(tmpdir): result_collection_reload( - {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0, "accelerator": "ddp", "gpus": 1} + { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 1, + } ) From a4128252673730a110c15c00631820ba4c51b918 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Jul 2021 13:28:42 +0200 Subject: [PATCH 12/29] update changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eae5e5baae1cb..c4a97f66ce221 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) +- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641)) + + - @@ -28,7 +31,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) - - Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) From 41a515675cfd5edec11f3258051e44c46616250d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Jul 2021 14:19:27 +0200 Subject: [PATCH 13/29] Reuse key definition --- pytorch_lightning/loops/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5ac8ed78d3289..dd29fd3933696 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -175,14 +175,15 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[prefix + "state_dict"] = self.on_save_checkpoint() for k, v in self.__dict__.items(): + key = prefix + k if isinstance(v, BaseProgress): - destination[prefix + k] = v.state_dict() + destination[key] = v.state_dict() elif isinstance(v, Loop): - v.state_dict(destination, prefix + k + ".") + v.state_dict(destination, key + ".") elif isinstance(v, ResultCollection): # sync / unsync metrics v.sync() - destination[prefix + k] = v.state_dict() + destination[key] = v.state_dict() v.unsync() return destination @@ -204,8 +205,9 @@ def _load_from_state_dict( self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None ) -> None: for k, v in self.__dict__.items(): + key = prefix + k if isinstance(v, BaseProgress): - v.load_state_dict(state_dict[prefix + k]) + v.load_state_dict(state_dict[key]) if restart_progress: apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) @@ -221,9 +223,7 @@ def _load_from_state_dict( # re-attach metrics references v.load_state_dict( - state_dict[prefix + k], - metrics=metric_attributes, - sync_fn=self.trainer.training_type_plugin.reduce, + state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce ) if not self.trainer.is_global_zero: From df2beae5516e0512c17a0685fbb0415c73f65b7f Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Jul 2021 18:31:08 +0200 Subject: [PATCH 14/29] updates on comments --- pytorch_lightning/loops/base.py | 9 ++- .../plugins/training_type/ddp.py | 1 - .../connectors/checkpoint_connector.py | 8 +-- tests/checkpointing/test_model_checkpoint.py | 8 +++ tests/loops/test_loop_state_dict.py | 65 ++++++++++++++----- 5 files changed, 66 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 5ac8ed78d3289..aada6d009df4c 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -210,7 +210,7 @@ def _load_from_state_dict( apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) elif isinstance(v, ResultCollection): - if isinstance(self.trainer, pl.Trainer) and getattr(self.trainer, "lightning_module", None) is not None: + if self.trainer is not None and 
getattr(self.trainer, "lightning_module", None) is not None: metric_attributes = { name: module for name, module in self.trainer.lightning_module.named_modules() @@ -219,7 +219,12 @@ def _load_from_state_dict( if metrics: metric_attributes.update(metrics) - # re-attach metrics references + # The `ResultCollection` objects has 2 types of metrics: tensor and torchmetrics.Metric. + # When creating a checkpoint, the Metric type are been dropped from the loop state_dict + # to serialize only pure Python primitives. + # However, their states are saved alongside the model state_dict. + # On reload, we need to re-attach the Metrics back the ResultCollection. + # The references are provided through ``metric_attributes`` dictionary. v.load_state_dict( state_dict[prefix + k], metrics=metric_attributes, diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index f03a439532b18..d842e7eb512c1 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -410,7 +410,6 @@ def register_plugins(cls, plugin_registry: Dict) -> None: def _share_information_to_prevent_deadlock(self): self._share_pids() - # remove `PL_DDP_SYNC_TMPDIR` from os.environ # FIXME: Add better support for deadlock detection. Changed TMPDIR at on every trainer.{call_fn}. self._sync_dir = os.getenv("PL_DDP_SYNC_TMPDIR") diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1e0ab08344c1b..bc198308f1f42 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -142,7 +142,7 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) - # reset metrics states on non-rank 0 as the states have been synced on-saving. + # reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing. if not self.trainer.is_global_zero: for module in self.trainer.lightning_module.modules(): if isinstance(module, Metric): @@ -219,7 +219,7 @@ def restore_loops(self) -> None: " consider using an end of epoch checkpoint." ) - state_dict = self._loaded_checkpoint.pop("loops", None) + state_dict = self._loaded_checkpoint.get("loops") if state_dict: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) @@ -452,13 +452,13 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metrics = [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] for metric in metrics: metric.persistent(True) - if not metric._is_synced: - metric.sync() + metric.sync() state_dict = self.trainer.accelerator.lightning_module_state_dict() if _fault_tolerant_enabled(): for metric in metrics: + # on cpu, sync is a no-op and therefore `unsync` call would fail as the metrics is not synced. 
if metric._is_synced: metric.unsync() diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d98c4dffb146c..b10ac5102d3de 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1328,6 +1328,14 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_result_collection_reload(tmpdir): + result_collection_reload( + {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + ) + + +@RunIf(min_gpus=1) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload_1_gpu_ddp(tmpdir): result_collection_reload( { "default_root_dir": tmpdir, diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 7f159b9e355ba..22d2be8c3a9b0 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -14,6 +14,7 @@ from unittest.mock import Mock import pytest +import torch from pytorch_lightning.loops import FitLoop from pytorch_lightning.trainer.trainer import Trainer @@ -38,78 +39,106 @@ def test_loops_state_dict_structure(): expected = { "fit_loop": { "state_dict": {}, - "epoch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, "epoch_loop.scheduler_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.optim_progress": { "optimizer": { "step": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "zero_grad": { - "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, }, }, "optimizer_idx": 0, }, - "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.val_loop.state_dict": {}, "epoch_loop.val_loop.dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, "epoch_loop.val_loop.epoch_loop.batch_progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "epoch_loop.val_loop._results": { + "training": False, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, + }, + "epoch_loop._results": { + "training": True, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, + }, + "epoch_progress": { "total": {"ready": 0, 
"started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, - "predict_loop": { + "validate_loop": { "state_dict": {}, "dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "_results": { + "training": False, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, }, }, "test_loop": { "state_dict": {}, "dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "_results": { + "training": False, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, }, }, - "validate_loop": { + "predict_loop": { "state_dict": {}, "dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, } From d978b2c541953f9cef10784e865c0ccee6171be7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Jul 2021 18:50:09 +0200 Subject: [PATCH 15/29] update --- tests/checkpointing/test_model_checkpoint.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b10ac5102d3de..a85361dd4dcf5 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1313,10 +1313,6 @@ def on_epoch_end(self) -> None: trainer.accelerator.barrier() - if trainer.is_global_zero: - checkpoint = torch.load(checkpoint_path) - assert checkpoint["state_dict"]["dummy_metric.sum"] == 3 * num_processes - trainer_kwargs["resume_from_checkpoint"] = checkpoint_path trainer_kwargs["max_epochs"] = 2 From a493d1fa674b9e4e333868ffbe15748a4621fee6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Jul 2021 18:52:25 +0200 Subject: [PATCH 16/29] Indentation and comments --- pytorch_lightning/loops/base.py | 48 +++++++++---------- .../connectors/checkpoint_connector.py | 25 +++++----- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 8f2db88b060b4..ee5c3a1b708f1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -211,30 +211,30 @@ def 
_load_from_state_dict( if restart_progress: apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) - elif isinstance(v, ResultCollection): - if self.trainer is not None and getattr(self.trainer, "lightning_module", None) is not None: - metric_attributes = { - name: module - for name, module in self.trainer.lightning_module.named_modules() - if isinstance(module, Metric) - } - if metrics: - metric_attributes.update(metrics) - - # The `ResultCollection` objects has 2 types of metrics: tensor and torchmetrics.Metric. - # When creating a checkpoint, the Metric type are been dropped from the loop state_dict - # to serialize only pure Python primitives. - # However, their states are saved alongside the model state_dict. - # On reload, we need to re-attach the Metrics back the ResultCollection. - # The references are provided through ``metric_attributes`` dictionary. - v.load_state_dict( - state_dict[prefix + k], - metrics=metric_attributes, - sync_fn=self.trainer.training_type_plugin.reduce, - ) - - if not self.trainer.is_global_zero: - v.reset(metrics=False) + elif ( + isinstance(v, ResultCollection) + and self.trainer is not None + and getattr(self.trainer, "lightning_module", None) is not None + ): + metric_attributes = { + name: module + for name, module in self.trainer.lightning_module.named_modules() + if isinstance(module, Metric) + } + if metrics: + metric_attributes.update(metrics) + + # The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`. + # When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only + # Python primitives. However, their states are saved with the model's `state_dict`. + # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`. + # The references are provided through the `metric_attributes` dictionary. + v.load_state_dict( + state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce + ) + + if not self.trainer.is_global_zero: + v.reset(metrics=False) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 0d9af8c839660..b10d949db0758 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -219,7 +219,7 @@ def restore_loops(self) -> None: " consider using an end of epoch checkpoint." 
) - state_dict = self._loaded_checkpoint.pop("loops", None) + state_dict = self._loaded_checkpoint.get("loops") if state_dict: self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) @@ -448,19 +448,22 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: - if _fault_tolerant_enabled(): - metrics = [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] - for metric in metrics: - metric.persistent(True) - metric.sync() + metrics = ( + [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] + if _fault_tolerant_enabled() + else [] + ) + + for metric in metrics: + metric.persistent(True) + metric.sync() state_dict = self.trainer.accelerator.lightning_module_state_dict() - if _fault_tolerant_enabled(): - for metric in metrics: - # on cpu, sync is a no-op and therefore `unsync` call would fail as the metrics is not synced. - if metric._is_synced: - metric.unsync() + for metric in metrics: + # sync can be a no-op (e.g. on cpu) so `unsync` would raise an user error exception if we don't check + if metric._is_synced: + metric.unsync() return state_dict From 0aa5659af023e9f55d7d1257a8057c08a1962669 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Jul 2021 18:55:49 +0200 Subject: [PATCH 17/29] apply comments --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index d02d7d683a986..77f5426a2d700 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -607,7 +607,7 @@ def cpu(self) -> "ResultCollection": def sync(self) -> None: for result_metric in self.result_metrics: - if result_metric.is_tensor and not result_metric._is_synced: + if result_metric.is_tensor: result_metric.sync() def unsync(self) -> None: From 7f00a8ba11fac6b23fe43f51bf85711fe41e6a9f Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 1 Aug 2021 12:20:57 +0200 Subject: [PATCH 18/29] update on comments --- tests/checkpointing/test_model_checkpoint.py | 137 ------------------- tests/core/test_metric_result_integration.py | 137 +++++++++++++++++++ 2 files changed, 137 insertions(+), 137 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a85361dd4dcf5..0906ed3820705 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -18,7 +18,6 @@ import re import time from argparse import Namespace -from contextlib import suppress from datetime import timedelta from logging import INFO from pathlib import Path @@ -32,14 +31,12 @@ import yaml from omegaconf import Container, OmegaConf from torch import optim -from torchmetrics import Metric import pytorch_lightning as pl import tests.helpers.utils as tutils from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource from pytorch_lightning.utilities.cloud_io import load as pl_load from 
pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -1223,137 +1220,3 @@ def test_trainer_checkpoint_callback_bool(tmpdir): mc = ModelCheckpoint(dirpath=tmpdir) with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"): Trainer(checkpoint_callback=mc) - - -class DummyMetric(Metric): - def __init__(self): - super().__init__() - self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) - self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) - - def update(self, increment): - self.sum += increment - self.count += 1 - - def compute(self): - return self.sum // self.count - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" - - -def result_collection_reload(trainer_kwargs): - num_processes = trainer_kwargs.get("gpus", 1) - - class CustomException(Exception): - pass - - class ExtendedBoringModel(BoringModel): - def __init__(self): - super().__init__() - self.has_reloaded = False - self.breaking_batch_idx = 3 - self.has_validated_sum = False - self.dummy_metric = DummyMetric() - - def training_step(self, batch, batch_idx): - assert len(batch) == 1 - if self.has_reloaded: - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) - - self.dummy_metric(batch_idx) - self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - - value = self.trainer.fit_loop._results["training_step.tracking_metric"].value - value_2 = self.trainer.fit_loop._results["training_step.tracking"].value - shift = 0 - if num_processes == 2: - shift = 3 if self.trainer.is_global_zero else -3 - expected = sum(range(batch_idx + 1)) + shift - assert expected == value == value_2 - else: - if batch_idx == self.breaking_batch_idx: - # simulate failure mid epoch - raise CustomException - - self.log("tracking", batch_idx, on_step=True, on_epoch=True) - self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) - - self.dummy_metric(batch_idx) - self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - - value = self.trainer.fit_loop._results["training_step.tracking"].value - assert value == sum(range(batch_idx + 1)) - - value = self.trainer.fit_loop._results["training_step.tracking_2"] - assert value == sum(range(batch_idx + 1)) - - return super().training_step(batch, batch_idx) - - def on_epoch_end(self) -> None: - if self.has_reloaded: - total = sum(range(5)) * num_processes - metrics = self.trainer.fit_loop._results.metrics(on_step=False) - assert self.trainer.fit_loop._results["training_step.tracking"].value == total - assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 - assert self.trainer.fit_loop._results["training_step.tracking_2"].value == total - assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 - self.has_validated_sum = True - - model = ExtendedBoringModel() - - trainer = Trainer(**trainer_kwargs) - - with suppress(CustomException): - trainer.fit(model) - - checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], "ckpt.pt")) - trainer.save_checkpoint(checkpoint_path) - - trainer.accelerator.barrier() - - trainer_kwargs["resume_from_checkpoint"] = checkpoint_path - trainer_kwargs["max_epochs"] = 2 - - trainer = Trainer(**trainer_kwargs) - model.has_reloaded = True - trainer.fit(model) - assert 
model.has_validated_sum - - -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_result_collection_reload(tmpdir): - result_collection_reload( - {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} - ) - - -@RunIf(min_gpus=1) -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_result_collection_reload_1_gpu_ddp(tmpdir): - result_collection_reload( - { - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 5, - "limit_val_batches": 0, - "accelerator": "ddp", - "gpus": 1, - } - ) - - -@RunIf(min_gpus=2, special=True) -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -def test_result_collection_reload_2_gpus(tmpdir): - result_collection_reload( - { - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 5, - "limit_val_batches": 0, - "accelerator": "ddp", - "gpus": 2, - } - ) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index fa2f9ccdf7c50..806266aa6e680 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import pickle +from contextlib import suppress from copy import deepcopy +from unittest import mock import pytest import torch @@ -353,3 +356,137 @@ def test_result_collection_extra_reference(): """Unit-test to check that the `extra` dict reference is properly set.""" rc = ResultCollection(True) assert rc.extra is rc["_extra"] + + +class DummyMeanMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, increment): + self.sum += increment + self.count += 1 + + def compute(self): + return self.sum // self.count + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" + + +def result_collection_reload(trainer_kwargs): + num_processes = trainer_kwargs.get("gpus", 1) + + class CustomException(Exception): + pass + + class ExtendedBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.has_reloaded = False + self.breaking_batch_idx = 3 + self.has_validated_sum = False + self.dummy_metric = DummyMeanMetric() + + def training_step(self, batch, batch_idx): + assert len(batch) == 1 + if self.has_reloaded: + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking_metric"].value + value_2 = self.trainer.fit_loop._results["training_step.tracking"].value + shift = 0 + if num_processes == 2: + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value == value_2 + else: + if batch_idx == self.breaking_batch_idx: + # simulate failure mid epoch + raise CustomException + + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, 
on_step=True, on_epoch=True) + + value = self.trainer.fit_loop._results["training_step.tracking"].value + assert value == sum(range(batch_idx + 1)) + + value = self.trainer.fit_loop._results["training_step.tracking_2"] + assert value == sum(range(batch_idx + 1)) + + return super().training_step(batch, batch_idx) + + def on_epoch_end(self) -> None: + if self.has_reloaded: + total = sum(range(5)) * num_processes + metrics = self.trainer.fit_loop._results.metrics(on_step=False) + assert self.trainer.fit_loop._results["training_step.tracking"].value == total + assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 + assert self.trainer.fit_loop._results["training_step.tracking_2"].value == total + assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 + self.has_validated_sum = True + + model = ExtendedBoringModel() + + trainer = Trainer(**trainer_kwargs) + + with suppress(CustomException): + trainer.fit(model) + + checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], "ckpt.pt")) + trainer.save_checkpoint(checkpoint_path) + + trainer.accelerator.barrier() + + trainer_kwargs["resume_from_checkpoint"] = checkpoint_path + trainer_kwargs["max_epochs"] = 2 + + trainer = Trainer(**trainer_kwargs) + model.has_reloaded = True + trainer.fit(model) + assert model.has_validated_sum + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload(tmpdir): + result_collection_reload( + {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + ) + + +@RunIf(min_gpus=1) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload_1_gpu_ddp(tmpdir): + result_collection_reload( + { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 1, + } + ) + + +@RunIf(min_gpus=2, special=True) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_result_collection_reload_2_gpus(tmpdir): + result_collection_reload( + { + "default_root_dir": tmpdir, + "max_epochs": 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accelerator": "ddp", + "gpus": 2, + } + ) From 768e4ea01a5228c82e39f696c3fbe43b050ab0f9 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Sun, 1 Aug 2021 06:38:16 -0400 Subject: [PATCH 19/29] resolve tests --- tests/loops/test_loops.py | 4 +++- tests/trainer/test_trainer.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 19e2c7628ce25..d136444baf2f4 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -380,7 +380,7 @@ def configure_optimizers_multiple(self): ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") checkpoint = torch.load(ckpt_path) - + optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress @@ -497,6 +497,8 @@ def configure_optimizers_multiple(self): "epoch_loop.val_loop.dataloader_progress": ANY, "epoch_loop.val_loop.epoch_loop.state_dict": ANY, "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, + 'epoch_loop.val_loop._results': ANY, + 'epoch_loop._results': ANY, } assert checkpoint["loops"]["fit_loop"] == expected diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 86ca0d1fc5618..578fcd4b7c1e3 100644 --- a/tests/trainer/test_trainer.py +++ 
b/tests/trainer/test_trainer.py @@ -20,7 +20,7 @@ from copy import deepcopy from pathlib import Path from unittest.mock import ANY, call, patch - +import gc import cloudpickle import pytest import torch @@ -1880,6 +1880,7 @@ def on_epoch_start(self, trainer, *_): ) trainer = Trainer(**trainer_kwargs) trainer.fit(model) + gc.collect() assert trainer.training_type_plugin.model is model assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu") @@ -1892,6 +1893,8 @@ def on_epoch_start(self, trainer, *_): trainer_2 = Trainer(**trainer_kwargs) trainer_2.fit(model) + gc.collect() + memory_3 = torch.cuda.memory_allocated(0) assert initial == memory_1 == memory_3 From 741627219b1a209f94fd883b841072c9f5407368 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Sun, 1 Aug 2021 06:40:41 -0400 Subject: [PATCH 20/29] typo --- tests/trainer/test_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ccff010ea2a91..b8e45e0a4c689 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1881,7 +1881,6 @@ def on_epoch_start(self, trainer, *_): ) trainer = Trainer(**trainer_kwargs) trainer.fit(model) - gc.collect() assert trainer.training_type_plugin.model is model assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu") From 07888e2d3bfbba0facfcb789c42f6571dc0fc833 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Aug 2021 10:41:04 +0000 Subject: [PATCH 21/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_loops.py | 6 +++--- tests/trainer/test_trainer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index d136444baf2f4..45c6453688939 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -380,7 +380,7 @@ def configure_optimizers_multiple(self): ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") checkpoint = torch.load(ckpt_path) - + optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress @@ -497,8 +497,8 @@ def configure_optimizers_multiple(self): "epoch_loop.val_loop.dataloader_progress": ANY, "epoch_loop.val_loop.epoch_loop.state_dict": ANY, "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, - 'epoch_loop.val_loop._results': ANY, - 'epoch_loop._results': ANY, + "epoch_loop.val_loop._results": ANY, + "epoch_loop._results": ANY, } assert checkpoint["loops"]["fit_loop"] == expected diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b8e45e0a4c689..a5ac053395515 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -21,7 +21,7 @@ from copy import deepcopy from pathlib import Path from unittest.mock import ANY, call, patch -import gc + import cloudpickle import pytest import torch From 7349ad157aa449c7b4aa837e1aa99e8881221885 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 1 Aug 2021 18:55:33 +0200 Subject: [PATCH 22/29] update --- tests/core/test_metric_result_integration.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 806266aa6e680..b71609dab72bb 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py 
@@ -27,6 +27,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -456,6 +457,9 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_7, reason="fault tolerant training is not support for PyTorch 1.6 and below" +) def test_result_collection_reload(tmpdir): result_collection_reload( {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} @@ -464,6 +468,9 @@ def test_result_collection_reload(tmpdir): @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_7, reason="fault tolerant training is not support for PyTorch 1.6 and below" +) def test_result_collection_reload_1_gpu_ddp(tmpdir): result_collection_reload( { @@ -479,6 +486,9 @@ def test_result_collection_reload_1_gpu_ddp(tmpdir): @RunIf(min_gpus=2, special=True) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_7, reason="fault tolerant training is not support for PyTorch 1.6 and below" +) def test_result_collection_reload_2_gpus(tmpdir): result_collection_reload( { From e451594a9023ec67673d17c9eac5ddb7afa689c0 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 2 Aug 2021 09:41:08 +0200 Subject: [PATCH 23/29] Update pytorch_lightning/trainer/connectors/checkpoint_connector.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6b787528501dc..c096f0a609378 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -464,7 +464,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: state_dict = self.trainer.accelerator.lightning_module_state_dict() for metric in metrics: - # sync can be a no-op (e.g. on cpu) so `unsync` would raise an user error exception if we don't check + # sync can be a no-op (e.g. 
on cpu) so `unsync` would raise a user error exception if we don't check if metric._is_synced: metric.unsync() From ec613745336d416d6d54d1b10ebe673f846280ef Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 2 Aug 2021 13:22:04 +0200 Subject: [PATCH 24/29] Refactor test --- tests/core/test_metric_result_integration.py | 65 +++++--------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index b71609dab72bb..c33475ced593e 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -27,7 +27,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -376,8 +376,11 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" -def result_collection_reload(trainer_kwargs): - num_processes = trainer_kwargs.get("gpus", 1) +def result_collection_reload(**kwargs): + if not _fault_tolerant_enabled(): + pytest.skip("Fault tolerant not available") + + num_processes = kwargs.get("gpus", 1) class CustomException(Exception): pass @@ -385,14 +388,12 @@ class CustomException(Exception): class ExtendedBoringModel(BoringModel): def __init__(self): super().__init__() - self.has_reloaded = False self.breaking_batch_idx = 3 self.has_validated_sum = False self.dummy_metric = DummyMeanMetric() def training_step(self, batch, batch_idx): - assert len(batch) == 1 - if self.has_reloaded: + if self.trainer.fit_loop.restarting: self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) @@ -426,7 +427,7 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) def on_epoch_end(self) -> None: - if self.has_reloaded: + if self.trainer.fit_loop.restarting: total = sum(range(5)) * num_processes metrics = self.trainer.fit_loop._results.metrics(on_step=False) assert self.trainer.fit_loop._results["training_step.tracking"].value == total @@ -436,67 +437,35 @@ def on_epoch_end(self) -> None: self.has_validated_sum = True model = ExtendedBoringModel() - + trainer_kwargs = {"max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + trainer_kwargs.update(kwargs) trainer = Trainer(**trainer_kwargs) with suppress(CustomException): trainer.fit(model) + assert not model.has_validated_sum - checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], "ckpt.pt")) - trainer.save_checkpoint(checkpoint_path) - - trainer.accelerator.barrier() - - trainer_kwargs["resume_from_checkpoint"] = checkpoint_path - trainer_kwargs["max_epochs"] = 2 + tmpdir = trainer_kwargs["default_root_dir"] + ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") + trainer_kwargs["resume_from_checkpoint"] = ckpt_path trainer = Trainer(**trainer_kwargs) - model.has_reloaded = True trainer.fit(model) assert model.has_validated_sum @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_7, reason="fault tolerant training is not support for PyTorch 1.6 and below" -) def 
test_result_collection_reload(tmpdir): - result_collection_reload( - {"default_root_dir": tmpdir, "max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} - ) + result_collection_reload(default_root_dir=tmpdir) @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_7, reason="fault tolerant training is not support for PyTorch 1.6 and below" -) def test_result_collection_reload_1_gpu_ddp(tmpdir): - result_collection_reload( - { - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 5, - "limit_val_batches": 0, - "accelerator": "ddp", - "gpus": 1, - } - ) + result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=1) @RunIf(min_gpus=2, special=True) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_7, reason="fault tolerant training is not support for PyTorch 1.6 and below" -) def test_result_collection_reload_2_gpus(tmpdir): - result_collection_reload( - { - "default_root_dir": tmpdir, - "max_epochs": 1, - "limit_train_batches": 5, - "limit_val_batches": 0, - "accelerator": "ddp", - "gpus": 2, - } - ) + result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=2) From d7c72dec08eca4698fb17cab666a8f3ce942e08d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 2 Aug 2021 14:26:31 +0200 Subject: [PATCH 25/29] add comments --- requirements.txt | 2 +- tests/core/test_metric_result_integration.py | 33 +++++++++++++++----- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index e886293b122b5..3141829cde5f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' -torchmetrics>=0.4.0 +torchmetrics>=0.4.1 pyDeprecate==0.3.1 packaging>=17.0 typing-extensions # TypedDict support for python<3.8 diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index c33475ced593e..93ebec685b98c 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -377,6 +377,12 @@ def __repr__(self) -> str: def result_collection_reload(**kwargs): + + """ + This test is going to validate ResultCollection is properly being reload + and final accumulation with Fault Tolerant Training is correct. + """ + if not _fault_tolerant_enabled(): pytest.skip("Fault tolerant not available") @@ -392,7 +398,17 @@ def __init__(self): self.has_validated_sum = False self.dummy_metric = DummyMeanMetric() + @property + def results(self): + return self.trainer.fit_loop._results + def training_step(self, batch, batch_idx): + + # In the training step, we will accumulate metrics using batch_idx from 0 to 4 + # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size` + # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`. + # However, below we will simulate a failure on ``batch_idx=3``. 
+ if self.trainer.fit_loop.restarting: self.log("tracking", batch_idx, on_step=True, on_epoch=True) self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) @@ -400,8 +416,11 @@ def training_step(self, batch, batch_idx): self.dummy_metric(batch_idx) self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - value = self.trainer.fit_loop._results["training_step.tracking_metric"].value - value_2 = self.trainer.fit_loop._results["training_step.tracking"].value + value = self.results["training_step.tracking_metric"].value + value_2 = self.results["training_step.tracking"].value + + # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks. + # The shift indicates we failed while the state was ``shift={sign(is_global_zero > 0)} * sum(range(3))`` shift = 0 if num_processes == 2: shift = 3 if self.trainer.is_global_zero else -3 @@ -418,10 +437,10 @@ def training_step(self, batch, batch_idx): self.dummy_metric(batch_idx) self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) - value = self.trainer.fit_loop._results["training_step.tracking"].value + value = self.results["training_step.tracking"].value assert value == sum(range(batch_idx + 1)) - value = self.trainer.fit_loop._results["training_step.tracking_2"] + value = self.results["training_step.tracking_2"] assert value == sum(range(batch_idx + 1)) return super().training_step(batch, batch_idx) @@ -429,10 +448,10 @@ def training_step(self, batch, batch_idx): def on_epoch_end(self) -> None: if self.trainer.fit_loop.restarting: total = sum(range(5)) * num_processes - metrics = self.trainer.fit_loop._results.metrics(on_step=False) - assert self.trainer.fit_loop._results["training_step.tracking"].value == total + metrics = self.results.metrics(on_step=False) + assert self.results["training_step.tracking"].value == total assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 - assert self.trainer.fit_loop._results["training_step.tracking_2"].value == total + assert self.results["training_step.tracking_2"].value == total assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 self.has_validated_sum = True From 5d1849692e25d2b28bcbb60d3f3906654bac8079 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 2 Aug 2021 14:50:25 +0200 Subject: [PATCH 26/29] nit --- tests/core/test_metric_result_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 93ebec685b98c..cd68239765175 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -407,7 +407,7 @@ def training_step(self, batch, batch_idx): # In the training step, we will accumulate metrics using batch_idx from 0 to 4 # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size` # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`. - # However, below we will simulate a failure on ``batch_idx=3``. + # However, below we will simulate a failure on `batch_idx=3`. if self.trainer.fit_loop.restarting: self.log("tracking", batch_idx, on_step=True, on_epoch=True) @@ -420,7 +420,7 @@ def training_step(self, batch, batch_idx): value_2 = self.results["training_step.tracking"].value # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks. 
- # The shift indicates we failed while the state was ``shift={sign(is_global_zero > 0)} * sum(range(3))`` + # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]` shift = 0 if num_processes == 2: shift = 3 if self.trainer.is_global_zero else -3 From a579b996a0ac1be5759ea067c078ea809f493265 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 2 Aug 2021 10:19:50 -0400 Subject: [PATCH 27/29] update --- tests/core/test_metric_result_integration.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index cd68239765175..ba85dbc35797c 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -27,6 +27,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -464,7 +465,7 @@ def on_epoch_end(self) -> None: trainer.fit(model) assert not model.has_validated_sum - tmpdir = trainer_kwargs["default_root_dir"] + tmpdir = trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0) if num_processes >=2 else trainer_kwargs["default_root_dir"] ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") trainer_kwargs["resume_from_checkpoint"] = ckpt_path @@ -474,17 +475,20 @@ def on_epoch_end(self) -> None: @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") def test_result_collection_reload(tmpdir): result_collection_reload(default_root_dir=tmpdir) @RunIf(min_gpus=1) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") def test_result_collection_reload_1_gpu_ddp(tmpdir): result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=1) @RunIf(min_gpus=2, special=True) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") def test_result_collection_reload_2_gpus(tmpdir): result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=2) From 0e377ddf3cac7d92d2070af245876a202b2d57ee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Aug 2021 14:20:51 +0000 Subject: [PATCH 28/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/test_metric_result_integration.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ba85dbc35797c..5bf6ece1fe1de 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -27,8 +27,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 
-from pytorch_lightning.utilities.imports import _fault_tolerant_enabled +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -465,7 +464,11 @@ def on_epoch_end(self) -> None: trainer.fit(model) assert not model.has_validated_sum - tmpdir = trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0) if num_processes >=2 else trainer_kwargs["default_root_dir"] + tmpdir = ( + trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0) + if num_processes >= 2 + else trainer_kwargs["default_root_dir"] + ) ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") trainer_kwargs["resume_from_checkpoint"] = ckpt_path From 923c74b9445e5bc472f208b727a4803bf70629a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Aug 2021 19:56:54 +0000 Subject: [PATCH 29/29] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/test_metric_result_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index ed39ae8abd61f..0c90dee2e5639 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -557,4 +557,3 @@ def on_train_epoch_end(self) -> None: trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=0) trainer.fit(model) -
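The following is a minimal, self-contained sketch of how the behaviour introduced by this series is exercised end to end. It is not part of any commit above; it is adapted from the tests added in these patches and assumes the same helpers (`BoringModel` from `tests.helpers`, `torchmetrics.Metric`), the `PL_FAULT_TOLERANT_TRAINING=1` flag, and the auto-saved `.pl_auto_save.ckpt` written on failure. The `ckpts/` directory name is a placeholder.

import os
from contextlib import suppress

import torch
from torchmetrics import Metric

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # enable the fault-tolerant code path


class DummyMeanMetric(Metric):
    def __init__(self):
        super().__init__()
        # states registered with `add_state` are synced and persisted into the
        # checkpoint "state_dict" when fault-tolerant training is enabled
        self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum)
        self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum)

    def update(self, increment):
        self.sum += increment
        self.count += 1

    def compute(self):
        return self.sum // self.count


class InterruptedException(Exception):
    pass


class Model(BoringModel):
    def __init__(self):
        super().__init__()
        self.metric = DummyMeanMetric()
        self.restarted = False

    def training_step(self, batch, batch_idx):
        if not self.restarted and batch_idx == 3:
            raise InterruptedException  # simulate a mid-epoch failure
        # logging a Metric stores a reference in the loop's ResultCollection;
        # on reload it is re-attached via the `metrics` argument of `Loop.load_state_dict`
        self.metric(batch_idx)
        self.log("tracking", self.metric, on_step=True, on_epoch=True)
        return super().training_step(batch, batch_idx)


model = Model()
trainer = Trainer(default_root_dir="ckpts", max_epochs=1, limit_train_batches=5, limit_val_batches=0)
with suppress(InterruptedException):
    trainer.fit(model)  # on failure, a `.pl_auto_save.ckpt` checkpoint is dumped

model.restarted = True
trainer = Trainer(
    default_root_dir="ckpts",
    max_epochs=1,
    limit_train_batches=5,
    limit_val_batches=0,
    resume_from_checkpoint=os.path.join("ckpts", ".pl_auto_save.ckpt"),
)
trainer.fit(model)  # loop progress, ResultCollection entries and metric states are restored

On resume, the metric states have already been synced and persisted alongside the model `state_dict` at save time, and the loop `load_state_dict` re-attaches them to the `ResultCollection`, so the epoch-level "tracking" value keeps accumulating across the failure.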