Save the ResultCollection in the loops state dict #8641

Merged: 35 commits (branch `add_support_for_logging`) merged on Aug 2, 2021
Commits (35)
f2f6858
wip
tchaton Jul 29, 2021
4307a05
resolve some issues
tchaton Jul 29, 2021
e63c560
add ResultCollection
tchaton Jul 30, 2021
bd91665
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2021
ac9e9f1
add comments
tchaton Jul 30, 2021
6e609f1
Merge branch 'add_support_for_logging' of https://github.com/PyTorchL…
tchaton Jul 30, 2021
46078e9
update changelog
tchaton Jul 30, 2021
f1d50d2
wip
tchaton Jul 29, 2021
aeaeee6
resolve some issues
tchaton Jul 29, 2021
a9368e9
add ResultCollection
tchaton Jul 30, 2021
c174917
add comments
tchaton Jul 30, 2021
3b7370a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2021
a412825
update changelog
tchaton Jul 30, 2021
41a5156
Reuse key definition
carmocca Jul 30, 2021
df2beae
updates on comments
tchaton Jul 30, 2021
8386cfc
Merge branch 'add_support_for_logging' of https://github.com/PyTorchL…
tchaton Jul 30, 2021
d978b2c
update
tchaton Jul 30, 2021
a493d1f
Indentation and comments
carmocca Jul 30, 2021
0aa5659
apply comments
tchaton Jul 30, 2021
7f00a8b
update on comments
tchaton Aug 1, 2021
768e4ea
resolve tests
tchaton Aug 1, 2021
dfbb051
Merge branch 'master' into add_support_for_logging
tchaton Aug 1, 2021
7416272
typo
tchaton Aug 1, 2021
07888e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 1, 2021
7349ad1
update
tchaton Aug 1, 2021
b99905b
Merge branch 'add_support_for_logging' of https://github.com/PyTorchL…
tchaton Aug 1, 2021
e451594
Update pytorch_lightning/trainer/connectors/checkpoint_connector.py
tchaton Aug 2, 2021
ec61374
Refactor test
carmocca Aug 2, 2021
d7c72de
add comments
tchaton Aug 2, 2021
5d18496
nit
carmocca Aug 2, 2021
a579b99
update
tchaton Aug 2, 2021
0e377dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2021
5d36135
Merge branch 'master' into add_support_for_logging
tchaton Aug 2, 2021
b7ce5ad
Merge branch 'master' into add_support_for_logging
tchaton Aug 2, 2021
923c74b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641))


-


@@ -28,7 +31,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))



- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


55 changes: 49 additions & 6 deletions pytorch_lightning/loops/base.py
@@ -16,8 +16,10 @@
from typing import Any, Dict, Optional

from deprecate import void
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -173,25 +175,66 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] =
destination[prefix + "state_dict"] = self.on_save_checkpoint()

for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
destination[prefix + k] = v.state_dict()
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, prefix + k + ".")
v.state_dict(destination, key + ".")
elif isinstance(v, ResultCollection):
# sync / unsync metrics
v.sync()
destination[key] = v.state_dict()
v.unsync()

return destination

def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None:
def load_state_dict(
self,
state_dict: Dict,
prefix: str = "",
restart_progress: bool = True,
metrics: Optional[Dict[str, Metric]] = None,
) -> None:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress)
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)

def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None:
def _load_from_state_dict(
self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
v.load_state_dict(state_dict[key])
if restart_progress:
apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())

elif (
isinstance(v, ResultCollection)
and self.trainer is not None
and getattr(self.trainer, "lightning_module", None) is not None
):
metric_attributes = {
name: module
for name, module in self.trainer.lightning_module.named_modules()
if isinstance(module, Metric)
}
if metrics:
metric_attributes.update(metrics)

# The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
)

if not self.trainer.is_global_zero:
v.reset(metrics=False)

self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
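To make the save/reload contract above concrete, here is a minimal sketch. It is illustrative only and assumes an already-configured `Trainer` instance named `trainer` (not defined in this diff): the loop serializes its `ResultCollection` into the flat state dict on save, and on reload the `torchmetrics.Metric` references are rebuilt from the `LightningModule`, with the new optional `metrics` argument available to supplement that mapping.

```python
# Illustrative sketch only -- assumes a configured `trainer` (not defined in this diff).
from torchmetrics import Metric

# Saving: ResultCollection attributes are synced, serialized, then unsynced.
loop_state = trainer.fit_loop.state_dict()

# Reloading: `Metric` objects are not stored in the loop state, so references are
# collected from the LightningModule (this mirrors what `_load_from_state_dict` does).
metric_attributes = {
    name: module
    for name, module in trainer.lightning_module.named_modules()
    if isinstance(module, Metric)
}

# The `metrics` argument is optional and only supplements the mapping built internally.
trainer.fit_loop.load_state_dict(loop_state, metrics=metric_attributes)
```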
33 changes: 30 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,9 +15,10 @@
import os
import re
from pathlib import Path
from typing import Optional, Union
from typing import Any, Dict, Optional, Union

import torch
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
@@ -141,6 +142,12 @@ def restore_model(self) -> None:
# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

# reset the metric states on non-rank-0 processes, as all states were accumulated on rank 0 via syncing on checkpointing.
if not self.trainer.is_global_zero:
for module in self.trainer.lightning_module.modules():
if isinstance(module, Metric):
module.reset()

def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
@@ -341,7 +348,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
"epoch": current_epoch,
"global_step": global_step,
"pytorch-lightning_version": pl.__version__,
"state_dict": self.trainer.accelerator.lightning_module_state_dict(),
"state_dict": self._get_lightning_module_state_dict(),
}
if _fault_tolerant_enabled():
checkpoint["loops"] = self._get_loops_state_dict()
@@ -443,7 +450,27 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
_checkpoint = self.dump_checkpoint(weights_only)
self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)

def _get_loops_state_dict(self):
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
[m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
if _fault_tolerant_enabled()
else []
)

for metric in metrics:
metric.persistent(True)
metric.sync()

state_dict = self.trainer.accelerator.lightning_module_state_dict()

for metric in metrics:
# sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
if metric._is_synced:
metric.unsync()

return state_dict

def _get_loops_state_dict(self) -> Dict[str, Any]:
return {
"fit_loop": self.trainer.fit_loop.state_dict(),
"validate_loop": self.trainer.validate_loop.state_dict(),
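The metric handling in `_get_lightning_module_state_dict` follows a generic torchmetrics pattern that can be shown in isolation. The `SumMetric` below is a hypothetical stand-in (not part of this PR), and the snippet assumes torchmetrics >= 0.4.1 as pinned in this PR's requirements:

```python
import torch
from torchmetrics import Metric


class SumMetric(Metric):
    """Hypothetical metric used only to illustrate the save pattern above."""

    def __init__(self):
        super().__init__()
        self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total


metric = SumMetric()
metric.update(torch.tensor(3.0))

# Same pattern as `_get_lightning_module_state_dict`: mark the states persistent so they
# are included in `state_dict()`, sync across processes, then unsync only if a sync
# actually happened (it is a no-op on CPU / a single process).
metric.persistent(True)
metric.sync()
state = metric.state_dict()  # now contains the accumulated "total"
if metric._is_synced:
    metric.unsync()
```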
pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -251,13 +251,14 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({state})"

def __getstate__(self, drop_value: bool = False) -> dict:
skip = ["update", "compute", "_update_signature"]
skip = ["update", "compute", "_update_signature", "_cache"]
if not self.is_tensor and drop_value:
# Avoid serializing ResultMetrics whose value is a `torchmetrics.Metric` (its state is saved with the model's state_dict instead)
skip.append("value")
d = {k: v for k, v in self.__dict__.items() if k not in skip}
d["meta"] = d["meta"].__getstate__()
d["_class"] = self.__class__.__name__
d["_is_synced"] = False # don't consider the state as synced on reload
return d

def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
@@ -604,6 +605,16 @@ def cpu(self) -> "ResultCollection":
"""Move all data to CPU."""
return self.to(device="cpu")

def sync(self) -> None:
for result_metric in self.result_metrics:
if result_metric.is_tensor:
result_metric.sync()

def unsync(self) -> None:
for result_metric in self.result_metrics:
if result_metric.is_tensor and result_metric._is_synced:
result_metric.unsync()

def __str__(self) -> str:
# sample output: `ResultCollection(minimize=1.23, {})`
minimize = f"minimize={self.minimize}, " if self.minimize is not None else ""
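The `__getstate__` change above (skipping `_cache` and forcing `_is_synced = False`) is an instance of a standard pickling pattern. The following self-contained sketch uses a made-up `ResultLike` class, not the real `ResultMetric`, purely to illustrate that pattern:

```python
import pickle


class ResultLike:
    """Made-up stand-in for ResultMetric, only to illustrate the __getstate__ pattern."""

    def __init__(self):
        self.value = 3.0
        self._cache = {"latest": 3.0}  # transient, recomputed after reload
        self._is_synced = True

    def __getstate__(self) -> dict:
        skip = ["_cache"]
        d = {k: v for k, v in self.__dict__.items() if k not in skip}
        d["_is_synced"] = False  # never consider a reloaded object as synced
        return d


restored = pickle.loads(pickle.dumps(ResultLike()))
assert restored._is_synced is False
assert not hasattr(restored, "_cache")
```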
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1321,7 +1321,7 @@ def _log_device_info(self) -> None:
)

def _on_expection(self):
if not self.is_global_zero or not _fault_tolerant_enabled():
if not _fault_tolerant_enabled():
return
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
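With the `is_global_zero` check removed, the fault-tolerant auto-save now runs on every rank when the feature is enabled. As a rough sketch (the flag is experimental, "run_dir" is a placeholder, and the model construction and fit call are omitted):

```python
import os

# Experimental flag read by `_fault_tolerant_enabled()`.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from pytorch_lightning import Trainer

trainer = Trainer(default_root_dir="run_dir", max_epochs=1)
# If `trainer.fit(model)` raises, `_on_expection` writes the auto-save checkpoint here:
auto_save_path = os.path.join(trainer.default_root_dir, ".pl_auto_save.ckpt")
```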
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ tqdm>=4.41.0
PyYAML>=5.1
fsspec[http]>=2021.05.0, !=2021.06.0
tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
torchmetrics>=0.4.0
torchmetrics>=0.4.1
pyDeprecate==0.3.1
packaging>=17.0
typing-extensions # TypedDict support for python<3.8
142 changes: 142 additions & 0 deletions tests/core/test_metric_result_integration.py
@@ -11,8 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from contextlib import suppress
from copy import deepcopy
from unittest import mock

import pytest
import torch
@@ -25,6 +28,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -356,6 +360,144 @@ def test_result_collection_extra_reference():
assert rc.extra is rc["_extra"]


class DummyMeanMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum)
self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum)

def update(self, increment):
self.sum += increment
self.count += 1

def compute(self):
return self.sum // self.count

def __repr__(self) -> str:
return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})"


def result_collection_reload(**kwargs):

"""
This test is going to validate ResultCollection is properly being reload
and final accumulation with Fault Tolerant Training is correct.
"""

if not _fault_tolerant_enabled():
pytest.skip("Fault tolerant not available")

num_processes = kwargs.get("gpus", 1)

class CustomException(Exception):
pass

class ExtendedBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.breaking_batch_idx = 3
self.has_validated_sum = False
self.dummy_metric = DummyMeanMetric()

@property
def results(self):
return self.trainer.fit_loop._results

def training_step(self, batch, batch_idx):

# In the training step, we will accumulate metrics using batch_idx from 0 to 4
# Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`
# Therefore, compute on `epoch_end` should provide 2 as `10 / 5`.
# However, below we will simulate a failure on `batch_idx=3`.

if self.trainer.fit_loop.restarting:
self.log("tracking", batch_idx, on_step=True, on_epoch=True)
self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

self.dummy_metric(batch_idx)
self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

value = self.results["training_step.tracking_metric"].value
value_2 = self.results["training_step.tracking"].value

# On failure, the Metric states are accumulated on rank 0 and zeroed out on the other ranks.
# The shift indicates the run failed while the state was `shift = sign(is_global_zero > 0) * [0..3]`
shift = 0
if num_processes == 2:
shift = 3 if self.trainer.is_global_zero else -3
expected = sum(range(batch_idx + 1)) + shift
assert expected == value == value_2
else:
if batch_idx == self.breaking_batch_idx:
# simulate failure mid epoch
raise CustomException

self.log("tracking", batch_idx, on_step=True, on_epoch=True)
self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

self.dummy_metric(batch_idx)
self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

value = self.results["training_step.tracking"].value
assert value == sum(range(batch_idx + 1))

value = self.results["training_step.tracking_2"]
assert value == sum(range(batch_idx + 1))

return super().training_step(batch, batch_idx)

def on_epoch_end(self) -> None:
if self.trainer.fit_loop.restarting:
total = sum(range(5)) * num_processes
metrics = self.results.metrics(on_step=False)
assert self.results["training_step.tracking"].value == total
assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2
assert self.results["training_step.tracking_2"].value == total
assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2
self.has_validated_sum = True

model = ExtendedBoringModel()
trainer_kwargs = {"max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0}
trainer_kwargs.update(kwargs)
trainer = Trainer(**trainer_kwargs)

with suppress(CustomException):
trainer.fit(model)
assert not model.has_validated_sum

tmpdir = (
trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
if num_processes >= 2
else trainer_kwargs["default_root_dir"]
)
ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
trainer_kwargs["resume_from_checkpoint"] = ckpt_path

trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
assert model.has_validated_sum


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
def test_result_collection_reload(tmpdir):
result_collection_reload(default_root_dir=tmpdir)


@RunIf(min_gpus=1)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
def test_result_collection_reload_1_gpu_ddp(tmpdir):
result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=1)


@RunIf(min_gpus=2, special=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
def test_result_collection_reload_2_gpus(tmpdir):
result_collection_reload(default_root_dir=tmpdir, accelerator="ddp", gpus=2)


def test_metric_collections(tmpdir):
"""This test ensures the metric attribute is properly found even with complex nested metric structure"""

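Returning to `result_collection_reload` above: the resume half of the flow simply points a fresh `Trainer` at the auto-saved checkpoint. A compressed sketch, with placeholder paths and the model construction omitted:

```python
import os

from pytorch_lightning import Trainer

# Placeholder directory; in the test this is the `default_root_dir` of the failed run.
ckpt_path = os.path.join("run_dir", ".pl_auto_save.ckpt")

trainer = Trainer(
    max_epochs=1,
    limit_train_batches=5,
    limit_val_batches=0,
    default_root_dir="run_dir",
    resume_from_checkpoint=ckpt_path,
)
# trainer.fit(model)  # ResultCollection and metric states resume from the failed batch
```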