Fix state reference in MetricCollection (#1076)
* update
* fix states
* docstring
* changelog
SkafteNicki authored Jun 8, 2022
1 parent 65594f1 commit 3a5b24e
Showing 5 changed files with 153 additions and 61 deletions.
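For context, a hedged sketch of the behaviour this commit addresses (the metric choices below are illustrative assumptions, not taken from the diff): metrics that land in the same compute group share their state tensors by reference, so before this fix, mutating a metric obtained from the collection (for example resetting it) could silently clear the state of its group members. With this change, accessing members through `items()`, `values()`, or `__getitem__` copies the shared state by default.

```python
# Minimal sketch, assuming torchmetrics >= 0.9 and two metrics whose states
# can be merged into a single compute group.
import torch
from torchmetrics import MetricCollection, Precision, Recall

collection = MetricCollection(Precision(num_classes=3), Recall(num_classes=3), compute_groups=True)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
collection.update(preds, target)

# After this fix, every metric yielded here carries its own copy of the shared state,
# so resetting one of them no longer wipes the accumulated state of its group members.
for name, metric in collection.items():
    print(name, metric.compute())
```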
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069))


- Fixed bug related to state reference in metric collection when using compute groups ([#1076](https://github.com/PyTorchLightning/metrics/pull/1076))


## [0.9.0] - 2022-05-30

### Added
13 changes: 11 additions & 2 deletions integrations/test_lightning.py
@@ -19,6 +19,7 @@
from torch.utils.data import DataLoader

from integrations.lightning.boring_model import BoringModel, RandomDataset
from tests.helpers.utilities import no_warning_call
from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric


@@ -210,7 +211,11 @@ def training_epoch_end(self, outs):
max_epochs=2,
log_every_n_steps=1,
)
trainer.fit(model)
with no_warning_call(
UserWarning,
match="Torchmetrics v0.9 introduced a new argument class property called.*",
):
trainer.fit(model)

logged = trainer.logged_metrics
assert torch.allclose(tensor(logged["sum_step"]), model.sum, atol=2e-4)
@@ -249,7 +254,11 @@ def training_epoch_end(self, outputs):
log_every_n_steps=1,
weights_summary=None,
)
trainer.fit(model)
with no_warning_call(
UserWarning,
match="Torchmetrics v0.9 introduced a new argument class property called.*",
):
trainer.fit(model)

logged = trainer.logged_metrics
assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum, atol=2e-4)
122 changes: 78 additions & 44 deletions tests/bases/test_collections.py
@@ -322,53 +322,87 @@ def compute(self):
),
],
)
@pytest.mark.parametrize(
"prefix, postfix",
[
[None, None],
["prefix_", None],
[None, "_postfix"],
["prefix_", "_postfix"],
],
)
def test_check_compute_groups(metrics, expected, prefix, postfix):
"""Check that compute groups are formed after initialization."""
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct without for comparison
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)

assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}

for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called
class TestComputeGroups:
@pytest.mark.parametrize(
"prefix, postfix",
[
[None, None],
["prefix_", None],
[None, "_postfix"],
["prefix_", "_postfix"],
],
)
def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix):
"""Check that compute groups are formed after initialization and that metrics are correctly computed."""
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct without for comparison
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)

assert m.compute_groups == expected
assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

# compare results for correctness
res_cg = m.compute()
res_without_cg = m2.compute()
for key in res_cg.keys():
assert torch.allclose(res_cg[key], res_without_cg[key])

m.reset()
m2.reset()
for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

assert m.compute_groups == expected
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

# compare results for correctness
res_cg = m.compute()
res_without_cg = m2.compute()
for key in res_cg.keys():
assert torch.allclose(res_cg[key], res_without_cg[key])

m.reset()
m2.reset()

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, method):
"""Check that whenever user call a methods that give access to the indivitual metric that state are copied
instead of just passed by reference."""
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

for _ in range(2): # repeat to emulate effect of multiple epochs
for _ in range(2): # repeat to emulate effect of multiple batches
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

def _compare(m1, m2):
for state in m1._defaults:
assert torch.allclose(getattr(m1, state), getattr(m2, state))
# if states are still by reference the reset will make following metrics fail
m1.reset()
m2.reset()

if method == "items":
for (name_cg, metric_cg), (name_no_cg, metric_no_cg) in zip(m.items(), m2.items()):
assert name_cg == name_no_cg
_compare(metric_cg, metric_no_cg)
if method == "values":
for metric_cg, metric_no_cg in zip(m.values(), m2.values()):
_compare(metric_cg, metric_no_cg)
if method == "keys":
for key in m.keys():
metric_cg, metric_no_cg = m[key], m2[key]
_compare(metric_cg, metric_no_cg)


@pytest.mark.parametrize(
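To make the grouping behaviour exercised by the tests above more concrete, here is a hedged stand-alone example; the metric choices and the exact contents of `compute_groups` are illustrative assumptions:

```python
import torch
from torchmetrics import ConfusionMatrix, MetricCollection, Precision, Recall

m = MetricCollection(
    Precision(num_classes=3),
    Recall(num_classes=3),
    ConfusionMatrix(num_classes=3),
    compute_groups=True,
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)  # compute groups are merged on the first update

# Precision and Recall accumulate the same statistics, so one would expect them to share
# a group while ConfusionMatrix keeps its own, e.g. {0: ['Precision', 'Recall'], 1: ['ConfusionMatrix']}
print(m.compute_groups)
print(m.compute())
```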
73 changes: 59 additions & 14 deletions torchmetrics/collections.py
@@ -52,11 +52,11 @@ class name as key for the output dict.
metric state and are therefore only different in their compute step e.g. accuracy, precision and recall
can all be computed from the true positives/negatives and false positives/negatives. By default,
this argument is ``True`` which enables this feature. Set this argument to `False` for disabling
this behaviour. Can also be set to a list of list of metrics for setting the compute groups yourself.
this behaviour. Can also be set to a list of lists of metrics for setting the compute groups yourself.
.. note::
Metric collections can be nested at initialization (see last example) but the output of the collection will
still be a single flattened dictionary combining the prefix and postfix arguments from the nested collection.
still be a single flatten dictionary combining the prefix and postfix arguments from the nested collection.
Raises:
ValueError:
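A brief, hedged illustration of the two forms the `compute_groups` argument described above can take; the manual grouping below assumes members are keyed by their class names, as is the collection's default:

```python
from torchmetrics import ConfusionMatrix, MetricCollection, Precision, Recall

# default: let the collection detect groups of metrics that share state
auto = MetricCollection(Precision(num_classes=3), Recall(num_classes=3), compute_groups=True)

# or spell the groups out yourself as a list of lists of metric names
manual = MetricCollection(
    Precision(num_classes=3),
    Recall(num_classes=3),
    ConfusionMatrix(num_classes=3),
    compute_groups=[["Precision", "Recall"], ["ConfusionMatrix"]],
)
```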
@@ -143,6 +143,7 @@ def __init__(
self.postfix = self._check_arg(postfix, "postfix")
self._enable_compute_groups = compute_groups
self._groups_checked: bool = False
self._state_is_copy: bool = False

self.add_metrics(metrics, *additional_metrics)

@@ -153,7 +154,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True)}
res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)}
res = _flatten_dict(res)
return {self._set_name(k): v for k, v in res.items()}

@@ -172,13 +173,19 @@ def update(self, *args: Any, **kwargs: Any) -> None:
for i in range(1, len(cg)): # copy over the update count
mi = getattr(self, cg[i])
mi._update_count = m0._update_count
if self._state_is_copy:
# If we have deep copied state in between updates, re-establish the link
self._compute_groups_create_state_ref()
self._state_is_copy = False
else: # the first update always do per metric to form compute groups
for _, m in self.items(keep_base=True):
for _, m in self.items(keep_base=True, copy_state=False):
m_kwargs = m._filter_kwargs(**kwargs)
m.update(*args, **m_kwargs)

if self._enable_compute_groups:
self._merge_compute_groups()
# create reference between states
self._compute_groups_create_state_ref()
self._groups_checked = True

def _merge_compute_groups(self) -> None:
@@ -241,24 +248,37 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:

return True

def compute(self) -> Dict[str, Any]:
"""Compute the result for each metric in the collection."""
if self._enable_compute_groups and self._groups_checked:
def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
"""Create reference between metrics in the same compute group.
Args:
copy: If `True` the metric state will between members will be copied instead
of just passed by reference
"""
if not self._state_is_copy:
for _, cg in self._groups.items():
m0 = getattr(self, cg[0])
# copy the state to the remaining metrics in the compute group
for i in range(1, len(cg)):
mi = getattr(self, cg[i])
for state in m0._defaults:
setattr(mi, state, getattr(m0, state))
res = {k: m.compute() for k, m in self.items(keep_base=True)}
m0_state = getattr(m0, state)
# Determine if we just should set a reference or a full copy
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
self._state_is_copy = copy

def compute(self) -> Dict[str, Any]:
"""Compute the result for each metric in the collection."""
res = {k: m.compute() for k, m in self.items(keep_base=True, copy_state=False)}
res = _flatten_dict(res)
return {self._set_name(k): v for k, v in res.items()}

def reset(self) -> None:
"""Iteratively call reset for each metric."""
for _, m in self.items(keep_base=True):
for _, m in self.items(keep_base=True, copy_state=False):
m.reset()
if self._enable_compute_groups and self._groups_checked:
# reset state reference
self._compute_groups_create_state_ref()

def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection":
"""Make a copy of the metric collection
@@ -276,7 +296,7 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) ->

def persistent(self, mode: bool = True) -> None:
"""Method for post-init to change if metric states should be saved to its state_dict."""
for _, m in self.items(keep_base=True):
for _, m in self.items(keep_base=True, copy_state=False):
m.persistent(mode)

def add_metrics(
@@ -388,15 +408,40 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
return self._modules.keys()
return self._to_renamed_ordered_dict().keys()

def items(self, keep_base: bool = False) -> Iterable[Tuple[str, Module]]:
def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs.
Args:
keep_base: Whether to add prefix/postfix on the items collection.
keep_base: Whether to add prefix/postfix on the collection.
copy_state:
If metric states should be copied between metrics in the same compute group or just passed by reference
"""
self._compute_groups_create_state_ref(copy_state)
if keep_base:
return self._modules.items()
return self._to_renamed_ordered_dict().items()

def values(self, copy_state: bool = True) -> Iterable[Module]:
"""Return an iterable of the ModuleDict values.
Args:
copy_state:
If metric states should be copied between metrics in the same compute group or just passed by reference
"""
self._compute_groups_create_state_ref(copy_state)
return self._modules.values()

def __getitem__(self, key: str, copy_state: bool = True) -> Module:
"""Retrieve a single metric from the collection.
Args:
key: name of metric to retrieve
copy_state:
If metric states should be copied between metrics in the same compute group or just passed by reference
"""
self._compute_groups_create_state_ref(copy_state)
return self._modules[key]

@staticmethod
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
if arg is None or isinstance(arg, str):
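Taken together, the new `copy_state` flag on `items()`, `values()`, and `__getitem__` defaults to copying the state shared inside a compute group, while internal call sites such as `forward`, `update`, `compute`, and `reset` pass `copy_state=False` to keep the reference-sharing behaviour. A hedged usage sketch (metric choices are assumptions):

```python
import torch
from torchmetrics import MetricCollection, Precision, Recall

collection = MetricCollection(Precision(num_classes=3), Recall(num_classes=3), compute_groups=True)
collection.update(torch.randn(10, 3).softmax(dim=-1), torch.randint(3, (10,)))

# default access copies the state shared within a compute group, so mutating one
# member no longer touches the others
precision = collection["Precision"]
precision.reset()
print(collection["Recall"].compute())  # still reflects the accumulated update

# copy_state=False skips the copying step; this is how the collection iterates
# over its members internally
for _, metric in collection.items(copy_state=False):
    pass
```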
3 changes: 2 additions & 1 deletion torchmetrics/metric.py
@@ -303,7 +303,8 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:

# reduce batch and global state
self._update_count = _update_count + 1
self._reduce_states(global_state)
with torch.no_grad():
self._reduce_states(global_state)

# restore context
self._is_synced = False
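The `torch.no_grad()` guard added above matters because the reduction merges the freshly computed batch state, which may still require grad, back into the accumulated global state. A rough illustration of the effect with made-up tensors rather than torchmetrics internals:

```python
import torch

global_state = torch.zeros(1)                     # state accumulated across previous batches
batch_state = torch.randn(1, requires_grad=True)  # state produced during the current forward call

with torch.no_grad():
    # the merged state carries no autograd history, so accumulating it across steps
    # does not keep each batch's computation graph alive
    global_state = global_state + batch_state

assert not global_state.requires_grad
```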