diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e6b66001862d..da5021cb1f8dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) +- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641)) + + - @@ -28,7 +31,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) - - Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 8219e2ad02fc8..ee5c3a1b708f1 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -16,8 +16,10 @@ from typing import Any, Dict, Optional from deprecate import void +from torchmetrics import Metric import pytorch_lightning as pl +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.progress import BaseProgress, Progress from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -173,25 +175,66 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = destination[prefix + "state_dict"] = self.on_save_checkpoint() for k, v in self.__dict__.items(): + key = prefix + k if isinstance(v, BaseProgress): - destination[prefix + k] = v.state_dict() + destination[key] = v.state_dict() elif isinstance(v, Loop): - v.state_dict(destination, prefix + k + ".") + v.state_dict(destination, key + ".") + elif isinstance(v, ResultCollection): + # sync / unsync metrics + v.sync() + destination[key] = v.state_dict() + v.unsync() return destination - def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None: + def load_state_dict( + self, + state_dict: Dict, + prefix: str = "", + restart_progress: bool = True, + metrics: Optional[Dict[str, Metric]] = None, + ) -> None: """Loads the state of this loop and all its children.""" - self._load_from_state_dict(state_dict.copy(), prefix, restart_progress) + self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics) for k, v in self.__dict__.items(): if isinstance(v, Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress) - def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None: + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None + ) -> None: for k, v in self.__dict__.items(): + key = prefix + k if isinstance(v, BaseProgress): - v.load_state_dict(state_dict[prefix + k]) + v.load_state_dict(state_dict[key]) if restart_progress: apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) + + elif ( + isinstance(v, ResultCollection) + and self.trainer is not None + and getattr(self.trainer, "lightning_module", None) is not None + ): + metric_attributes = { + name: module + for name, module in 
self.trainer.lightning_module.named_modules() + if isinstance(module, Metric) + } + if metrics: + metric_attributes.update(metrics) + + # The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`. + # When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only + # Python primitives. However, their states are saved with the model's `state_dict`. + # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`. + # The references are provided through the `metric_attributes` dictionary. + v.load_state_dict( + state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce + ) + + if not self.trainer.is_global_zero: + v.reset(metrics=False) + self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3bb2dbd3ea61e..c096f0a609378 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -15,9 +15,10 @@ import os import re from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch +from torchmetrics import Metric import pytorch_lightning as pl from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn @@ -141,6 +142,12 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + # reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing. + if not self.trainer.is_global_zero: + for module in self.trainer.lightning_module.modules(): + if isinstance(module, Metric): + module.reset() + def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None: """Restore only the model weights.""" checkpoint = self._loaded_checkpoint @@ -341,7 +348,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: "epoch": current_epoch, "global_step": global_step, "pytorch-lightning_version": pl.__version__, - "state_dict": self.trainer.accelerator.lightning_module_state_dict(), + "state_dict": self._get_lightning_module_state_dict(), } if _fault_tolerant_enabled(): checkpoint["loops"] = self._get_loops_state_dict() @@ -443,7 +450,27 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: _checkpoint = self.dump_checkpoint(weights_only) self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) - def _get_loops_state_dict(self): + def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: + metrics = ( + [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)] + if _fault_tolerant_enabled() + else [] + ) + + for metric in metrics: + metric.persistent(True) + metric.sync() + + state_dict = self.trainer.accelerator.lightning_module_state_dict() + + for metric in metrics: + # sync can be a no-op (e.g. 
on cpu) so `unsync` would raise a user error exception if we don't check + if metric._is_synced: + metric.unsync() + + return state_dict + + def _get_loops_state_dict(self) -> Dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 28dda272485af..2b2e4613f2298 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -251,13 +251,14 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({state})" def __getstate__(self, drop_value: bool = False) -> dict: - skip = ["update", "compute", "_update_signature"] + skip = ["update", "compute", "_update_signature", "_cache"] if not self.is_tensor and drop_value: # Avoid serializing ResultMetrics which are passed Metrics skip.append("value") d = {k: v for k, v in self.__dict__.items() if k not in skip} d["meta"] = d["meta"].__getstate__() d["_class"] = self.__class__.__name__ + d["_is_synced"] = False # don't consider the state as synced on reload return d def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None: @@ -604,6 +605,16 @@ def cpu(self) -> "ResultCollection": """Move all data to CPU.""" return self.to(device="cpu") + def sync(self) -> None: + for result_metric in self.result_metrics: + if result_metric.is_tensor: + result_metric.sync() + + def unsync(self) -> None: + for result_metric in self.result_metrics: + if result_metric.is_tensor and result_metric._is_synced: + result_metric.unsync() + def __str__(self) -> str: # sample output: `ResultCollection(minimize=1.23, {})` minimize = f"minimize={self.minimize}, " if self.minimize is not None else "" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5b1fd7cdd6fb5..d2c6fab7ba559 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1321,7 +1321,7 @@ def _log_device_info(self) -> None: ) def _on_expection(self): - if not self.is_global_zero or not _fault_tolerant_enabled(): + if not _fault_tolerant_enabled(): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") diff --git a/requirements.txt b/requirements.txt index e886293b122b5..3141829cde5f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' -torchmetrics>=0.4.0 +torchmetrics>=0.4.1 pyDeprecate==0.3.1 packaging>=17.0 typing-extensions # TypedDict support for python<3.8 diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 4cd88065493d9..0c90dee2e5639 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os import pickle +from contextlib import suppress from copy import deepcopy +from unittest import mock import pytest import torch @@ -25,6 +28,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7 from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -356,6 +360,144 @@ def test_result_collection_extra_reference(): assert rc.extra is rc["_extra"] +class DummyMeanMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum) + self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum) + + def update(self, increment): + self.sum += increment + self.count += 1 + + def compute(self): + return self.sum // self.count + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})" + + +def result_collection_reload(**kwargs): + + """ + This test is going to validate ResultCollection is properly being reload + and final accumulation with Fault Tolerant Training is correct. + """ + + if not _fault_tolerant_enabled(): + pytest.skip("Fault tolerant not available") + + num_processes = kwargs.get("gpus", 1) + + class CustomException(Exception): + pass + + class ExtendedBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.breaking_batch_idx = 3 + self.has_validated_sum = False + self.dummy_metric = DummyMeanMetric() + + @property + def results(self): + return self.trainer.fit_loop._results + + def training_step(self, batch, batch_idx): + + # In the training step, we will accumulate metrics using batch_idx from 0 to 4 + # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size` + # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`. + # However, below we will simulate a failure on `batch_idx=3`. + + if self.trainer.fit_loop.restarting: + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.results["training_step.tracking_metric"].value + value_2 = self.results["training_step.tracking"].value + + # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks. 
+ # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]` + shift = 0 + if num_processes == 2: + shift = 3 if self.trainer.is_global_zero else -3 + expected = sum(range(batch_idx + 1)) + shift + assert expected == value == value_2 + else: + if batch_idx == self.breaking_batch_idx: + # simulate failure mid epoch + raise CustomException + + self.log("tracking", batch_idx, on_step=True, on_epoch=True) + self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True) + + self.dummy_metric(batch_idx) + self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True) + + value = self.results["training_step.tracking"].value + assert value == sum(range(batch_idx + 1)) + + value = self.results["training_step.tracking_2"] + assert value == sum(range(batch_idx + 1)) + + return super().training_step(batch, batch_idx) + + def on_epoch_end(self) -> None: + if self.trainer.fit_loop.restarting: + total = sum(range(5)) * num_processes + metrics = self.results.metrics(on_step=False) + assert self.results["training_step.tracking"].value == total + assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2 + assert self.results["training_step.tracking_2"].value == total + assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2 + self.has_validated_sum = True + + model = ExtendedBoringModel() + trainer_kwargs = {"max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0} + trainer_kwargs.update(kwargs) + trainer = Trainer(**trainer_kwargs) + + with suppress(CustomException): + trainer.fit(model) + assert not model.has_validated_sum + + tmpdir = ( + trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0) + if num_processes >= 2 + else trainer_kwargs["default_root_dir"] + ) + ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") + trainer_kwargs["resume_from_checkpoint"] = ckpt_path + + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) + assert model.has_validated_sum + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") +def test_result_collection_reload(tmpdir): + result_collection_reload(default_root_dir=tmpdir) + + +@RunIf(min_gpus=1) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") +def test_result_collection_reload_1_gpu_ddp(tmpdir): + result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=1) + + +@RunIf(min_gpus=2, special=True) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7") +def test_result_collection_reload_2_gpus(tmpdir): + result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=2) + + def test_metric_collections(tmpdir): """This test ensures the metric attribute is properly found even with complex nested metric structure""" diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 7f159b9e355ba..22d2be8c3a9b0 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -14,6 +14,7 @@ from unittest.mock import Mock import pytest +import torch from pytorch_lightning.loops import FitLoop from pytorch_lightning.trainer.trainer import Trainer @@ -38,78 +39,106 @@ def test_loops_state_dict_structure(): expected = { 
"fit_loop": { "state_dict": {}, - "epoch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, "epoch_loop.scheduler_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.optim_progress": { "optimizer": { "step": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "zero_grad": { - "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, }, }, "optimizer_idx": 0, }, - "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.val_loop.state_dict": {}, "epoch_loop.val_loop.dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, "epoch_loop.val_loop.epoch_loop.batch_progress": { + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "epoch_loop.val_loop._results": { + "training": False, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, + }, + "epoch_loop._results": { + "training": True, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, + }, + "epoch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, - "predict_loop": { + "validate_loop": { "state_dict": {}, "dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "_results": { + "training": False, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, }, }, "test_loop": { "state_dict": {}, "dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 
0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "_results": { + "training": False, + "_minimize": None, + "_batch_size": torch.tensor(1), + "device": None, + "items": {}, }, }, - "validate_loop": { + "predict_loop": { "state_dict": {}, "dataloader_progress": { - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, } diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 19e2c7628ce25..45c6453688939 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -497,6 +497,8 @@ def configure_optimizers_multiple(self): "epoch_loop.val_loop.dataloader_progress": ANY, "epoch_loop.val_loop.epoch_loop.state_dict": ANY, "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, + "epoch_loop.val_loop._results": ANY, + "epoch_loop._results": ANY, } assert checkpoint["loops"]["fit_loop"] == expected