Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add sync_context and sync to nn.Metric #302

Merged
merged 49 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
80575de
wip
tchaton Jun 17, 2021
904ecec
add _apply_sync to nn.Metric
tchaton Jun 17, 2021
71ca9be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
e4d99d8
move to context manager
tchaton Jun 17, 2021
94d3450
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 17, 2021
bfcbc74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
41a60e7
resolve flake8
tchaton Jun 17, 2021
b0498b4
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 17, 2021
94fab1b
add sync
tchaton Jun 17, 2021
31498ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2021
31563dc
update
tchaton Jun 17, 2021
7d24123
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 17, 2021
15e6d9a
update on comments
tchaton Jun 18, 2021
b3d5ec5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
d4367db
Merge branch 'master' into apply_sync_fn
tchaton Jun 18, 2021
fc42bbe
update
tchaton Jun 18, 2021
0dcb041
update
tchaton Jun 18, 2021
980329d
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
3b4e838
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
140aeeb
add restore_cache
tchaton Jun 18, 2021
1cf6a44
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
3ab11cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
ca04cfb
add a sync test
tchaton Jun 18, 2021
f2ae287
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
45b1b1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
8aa2a74
resolve flake8
tchaton Jun 18, 2021
0f2ed93
resolve loading
tchaton Jun 18, 2021
de32cbf
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
25bfbf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
7222aa3
resolve flake8
tchaton Jun 18, 2021
6dd7705
Update torchmetrics/metric.py
tchaton Jun 18, 2021
da215ad
remove _update_signature
tchaton Jun 18, 2021
dc8e699
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
fe456f2
Apply suggestions from code review
Borda Jun 18, 2021
303e829
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
586ae75
update on comments
tchaton Jun 18, 2021
a460801
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
409e20a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
71fad52
add missing is_distributed_fn
tchaton Jun 18, 2021
b7e2030
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
e72de7d
update on comments
tchaton Jun 18, 2021
11a3ab8
Update torchmetrics/metric.py
carmocca Jun 18, 2021
d9c0a53
resolve failing test
tchaton Jun 18, 2021
f37e77f
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 18, 2021
6e7e3a8
Deepsource smells
carmocca Jun 18, 2021
7ee31d0
Apply suggestions from code review
Borda Jun 21, 2021
99d99e1
update
tchaton Jun 21, 2021
0df9659
Merge branch 'apply_sync_fn' of https://github.com/PyTorchLightning/m…
tchaton Jun 21, 2021
9872ddd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from copy import deepcopy
from unittest import mock

import pytest
import torch
Expand Down Expand Up @@ -116,3 +119,56 @@ def compute(self):
def test_non_contiguous_tensors():
""" Test that gather_all operation works for non contiguous tensors """
torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2)


def _test_state_dict_is_synced(rank, worldsize, tmpdir):
setup_ddp(rank, worldsize)

class DummyCatMetric(Metric):

def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)

def update(self, x):
self.x += x
self.c += 1

def compute(self):
return self.x / self.c

metric = DummyCatMetric()
metric.persistent(True)

steps = 5
for i in range(steps):
metric(i)
state_dict = metric.state_dict()
print(state_dict)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

sum = i * (i + 1) / 2
tchaton marked this conversation as resolved.
Show resolved Hide resolved
assert state_dict["x"] == sum * worldsize
assert metric.x == sum
assert metric.c == (i + 1)
assert state_dict["c"] == metric.c * worldsize

def reload_state_dict(state_dict, expected_x, expected_c):
metric = DummyCatMetric()
metric.load_state_dict(state_dict)
assert metric.x == expected_x
assert metric.c == expected_c

with mock.patch.dict(os.environ, {"GLOBAL_RANK": str(rank)}):
reload_state_dict(deepcopy(state_dict), 20 if not rank else 0, 10 if not rank else 0)

reload_state_dict(deepcopy(state_dict), 20, 10)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_state_dict_is_synced(tmpdir):
"""
This test asserts taht metric are synced while creating the state
dict but restored after to continue accumulation.
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""
torch.multiprocessing.spawn(_test_state_dict_is_synced, args=(2, tmpdir), nprocs=2)
154 changes: 118 additions & 36 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import functools
import inspect
import operator
import os
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import Tensor, nn
Expand All @@ -28,6 +30,10 @@
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version


def is_distributed_fn() -> bool:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
return torch.distributed.is_available() and torch.distributed.is_initialized()


class Metric(nn.Module, ABC):
"""
Base class for all metrics present in the Metrics API.
Expand Down Expand Up @@ -83,6 +89,7 @@ def __init__(
self.process_group = process_group
self.dist_sync_fn = dist_sync_fn
self._to_sync = True
self._restore_cache = True

self._update_signature = inspect.signature(self.update)
self.update = self._wrap_update(self.update)
Expand Down Expand Up @@ -169,6 +176,9 @@ def forward(self, *args, **kwargs):

if self.compute_on_step:
self._to_sync = self.dist_sync_on_step
# skip restore cache operation from compute
# as cache is stored below.
Borda marked this conversation as resolved.
Show resolved Hide resolved
self._restore_cache = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
Expand All @@ -181,27 +191,31 @@ def forward(self, *args, **kwargs):
# restore context
for attr, val in cache.items():
setattr(self, attr, val)

self._restore_cache = True
self._to_sync = True
self._computed = None

return self._forward_cache

def _sync_dist(self, dist_sync_fn=gather_all_tensors):
def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None):
input_dict = {attr: getattr(self, attr) for attr in self._reductions}

for attr, reduction_fn in self._reductions.items():
# pre-concatenate metric states that are lists to reduce number of all_gather operations
if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
input_dict[attr] = [dim_zero_cat(input_dict[attr])]

output_dict = apply_to_collection(
input_dict,
Tensor,
dist_sync_fn,
group=self.process_group,
group=process_group or self.process_group,
)

for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)
if isinstance(output_dict[attr][0], Tensor):
if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this check is not safe. we're seeing errors as a result.

if isinstance(output_dict[attr], Sequence) and isinstance(output_dict[attr][0], Tensor):

IndexError: list index out of range

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in #311

output_dict[attr] = torch.stack(output_dict[attr])
elif isinstance(output_dict[attr][0], list):
output_dict[attr] = _flatten(output_dict[attr])
Expand All @@ -221,6 +235,78 @@ def wrapped_func(*args, **kwargs):

return wrapped_func

def sync(
self,
dist_sync_fn: Optional[Callable] = None,
process_group: Optional[Any] = None,
should_sync: bool = True,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused to see should_sync=True|False.

If you set False, this method does nothing, so it's the same as not calling sync in the first place!
Then, if you set True but dist is not available, it will do nothing so basically it does not what the user wants.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should_sync means should_sync if possible :) Modified the docstring to reflect this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It means now that there are two arguments overlapping: dist_sync_fn and should_sync

You can do this: should_sync=False and dist_sync_fn=mean

what willl happen now? will it sync or not?
@PyTorchLightning/core-metrics be aware of these cases

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree. I wonder if the main usage of should_sync is just in sync_context and maybe we should just decide there if syncing is needed or not? Doing an if with context manager is a bit harder and might justify a flag, but for a function, it should be easy for people to just avoid calling it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maximsch2 your argument is to keep the flag for the context manager but remove it from this function, correct?

I think that would be fine.

is_distributed_fn: Optional[Callable] = is_distributed_fn,
) -> Dict[str, Tensor]:
"""
Sync function for manually controlling when metrics states should be synced across processes

Args:
dist_sync_fn: Function to be used to perform states synchronization
process_group:
Specify the process group on which synchronization is called.
default: None (which selects the entire world)
should_sync: Whether to apply to state synchronization.

Returns:
cache: A dictionarry containing the local metric states. The cache will be empty if sync didn't happen.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""
is_distributed = is_distributed_fn()

if dist_sync_fn is None:
dist_sync_fn = gather_all_tensors

cache = {}

if is_distributed and should_sync:
# cache prior to syncing
cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

# sync
self._sync_dist(dist_sync_fn, process_group=process_group)

tchaton marked this conversation as resolved.
Show resolved Hide resolved
return cache

@contextmanager
def sync_context(
self,
dist_sync_fn: Optional[Callable] = None,
process_group: Optional[Any] = None,
should_sync: bool = True,
restore_cache: bool = True,
is_distributed_fn: Optional[Callable] = is_distributed_fn,
) -> None:
"""
Context manager to synchronize the states between processes when running in a distributed setting
and restore the local cache states after yielding.

Args:
dist_sync_fn: Function to be used to perform states synchronization
process_group:
Specify the process group on which synchronization is called.
default: None (which selects the entire world)
should_sync: Whether to apply to state synchronization.
restore_cache: Whether to restore the cache state so that the metrics can
continue to be accumulated.
"""
cache = self.sync(
dist_sync_fn=dist_sync_fn,
process_group=process_group,
should_sync=should_sync,
is_distributed_fn=is_distributed_fn
)

yield

if cache and restore_cache:
# if we synced, restore to cache so that we can continue to accumulate un-synced state
for attr, val in cache.items():
setattr(self, attr, val)

Comment on lines +303 to +307
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the use case for this? If we sync, we should assume all metrics are operating off the synced state and not accumulate local changes, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was here already, just moved.

Added in dd1e744
cc: @SkafteNicki

def _wrap_compute(self, compute):

@functools.wraps(compute)
Expand All @@ -236,26 +322,10 @@ def wrapped_func(*args, **kwargs):
if self._computed is not None:
return self._computed

dist_sync_fn = self.dist_sync_fn
if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized():
# User provided a bool, so we assume DDP if available
dist_sync_fn = gather_all_tensors

synced = False
cache = []
if self._to_sync and dist_sync_fn is not None:
# cache prior to syncing
cache = {attr: getattr(self, attr) for attr in self._defaults}

# sync
self._sync_dist(dist_sync_fn)
synced = True

self._computed = compute(*args, **kwargs)
if synced:
# if we synced, restore to cache so that we can continue to accumulate un-synced state
for attr, val in cache.items():
setattr(self, attr, val)
with self.sync_context(
dist_sync_fn=self.dist_sync_fn, should_sync=self._to_sync, restore_cache=self._restore_cache
):
self._computed = compute(*args, **kwargs)

return self._computed

Expand Down Expand Up @@ -299,11 +369,12 @@ def clone(self):

def __getstate__(self):
# ignore update and compute functions for pickling
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute", "_update_signature"]}

def __setstate__(self, state):
# manually restore update and compute functions for pickling
self.__dict__.update(state)
self._update_signature = inspect.signature(self.update)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)

Expand Down Expand Up @@ -341,16 +412,24 @@ def persistent(self, mode: bool = False):
def state_dict(self, destination=None, prefix="", keep_vars=False):
destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
# Register metric states to be part of the state_dict
for key in self._defaults:
if self._persistent[key]:
current_val = getattr(self, key)
if not keep_vars:
if torch.is_tensor(current_val):
current_val = current_val.detach()
elif isinstance(current_val, list):
current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val]
destination[prefix + key] = current_val
return destination
with self.sync_context(dist_sync_fn=self.dist_sync_fn):
for key in self._defaults:
if self._persistent[key]:
current_val = getattr(self, key)
if not keep_vars:
if isinstance(current_val, torch.Tensor):
current_val = current_val.detach()
elif isinstance(current_val, list):
current_val = [
cur_v.detach() if isinstance(cur_v, torch.Tensor) else cur_v for cur_v in current_val
]
destination[prefix + key] = deepcopy(current_val)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
return destination

def _on_load_from_state_dict(self, state_dict, key, name) -> None:
value = state_dict.pop(name)
if os.getenv("GLOBAL_RANK", "0") == "0":
setattr(self, key, value)
Borda marked this conversation as resolved.
Show resolved Hide resolved

def _load_from_state_dict(
self,
Expand All @@ -363,10 +442,13 @@ def _load_from_state_dict(
error_msgs: List[str],
) -> None:
""" Loads metric states from state_dict """

# only global rank 0 should be reloading the values present in the ``state_dict``
# as the state contains synced values across all progress_group
for key in self._defaults:
name = prefix + key
if name in state_dict:
setattr(self, key, state_dict.pop(name))
self._on_load_from_state_dict(state_dict, key, name)
super()._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
)
Expand Down