Skip to content

Commit

Permalink
Specify usage of metric tracker (#1608)
Browse files Browse the repository at this point in the history
* fixing
* docs
* changelog
* mypy issue
  • Loading branch information
SkafteNicki authored Mar 11, 2023
1 parent 78e9571 commit 3f74920
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed support in `MetricTracker` for `MultioutputWrapper` and nested structures ([#1608](https://github.com/Lightning-AI/metrics/pull/1608))


## [0.11.4] - 2023-03-10
Expand Down
46 changes: 33 additions & 13 deletions src/torchmetrics/wrappers/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ class MetricTracker(ModuleList):
-``MetricTracker.compute_all()``: get the metric value for all steps
-``MetricTracker.best_metric()``: returns the best value
Out of the box, this wrapper class fully supports that the base metric being tracked is a single `Metric`, a
`MetricCollection` or another `MetricWrapper` wrapped around a metric. However, multiple layers of nesting, such
as using a `Metric` inside a `MetricWrapper` inside a `MetricCollection` is not fully supported, especially the
`.best_metric` method that cannot auto compute the best metric and index for such nested structures.
Args:
metric: instance of a ``torchmetrics.Metric`` or ``torchmetrics.MetricCollection``
to keep track of at each timestep.
Expand Down Expand Up @@ -135,12 +140,13 @@ def compute(self) -> Any:
self._check_for_increment("compute")
return self[-1].compute()

def compute_all(self) -> Union[Tensor, Dict[str, Tensor]]:
def compute_all(self) -> Any:
"""Compute the metric value for all tracked metrics.
Return:
Either a single tensor if the tracked base object is a single metric, else if a metric collection is
provide a dict of tensors will be returned
By default will try stacking the results from all increments into a single tensor if the tracked base
object is a single metric. If a metric collection is provided a dict of stacked tensors will be returned.
If the stacking process fails a list of the computed results will be returned.
Raises:
ValueError:
Expand All @@ -149,10 +155,15 @@ def compute_all(self) -> Union[Tensor, Dict[str, Tensor]]:
self._check_for_increment("compute_all")
# The i!=0 accounts for the self._base_metric should be ignored
res = [metric.compute() for i, metric in enumerate(self) if i != 0]
if isinstance(self._base_metric, MetricCollection):
keys = res[0].keys()
return {k: torch.stack([r[k] for r in res], dim=0) for k in keys}
return torch.stack(res, dim=0)
try:
if isinstance(res[0], dict):
keys = res[0].keys()
return {k: torch.stack([r[k] for r in res], dim=0) for k in keys}
if isinstance(res[0], list):
return torch.stack([torch.stack(r, dim=0) for r in res], 0)
return torch.stack(res, dim=0)
        except TypeError:  # fallback solution to just return as it is if we cannot successfully stack
return res

def reset(self) -> None:
"""Reset the current metric being tracked."""
Expand Down Expand Up @@ -181,7 +192,6 @@ def best_metric(
Returns:
Either a single value or a tuple, depends on the value of ``return_step`` and the object being tracked.
- If a single metric is being tracked and ``return_step=False`` then a single tensor will be returned
- If a single metric is being tracked and ``return_step=True`` then a 2-element tuple will be returned,
where the first value is optimal value and second value is the corresponding optimal step
Expand All @@ -193,12 +203,23 @@ def best_metric(
of the first dict being the optimal values and the values of the second dict being the optimal step
In addition the value in all cases may be ``None`` if the underlying metric does not have a properly defined way
of being optimal.
of being optimal or in the case where a nested structure of metrics are being tracked.
"""
res = self.compute_all()
if isinstance(res, list):
rank_zero_warn(
"Encountered nested structure. You are probably using a metric collection inside a metric collection, or"
" a metric wrapper inside a metric collection, which is not supported by the `.best_metric()` method."
" Returning `None` instead. Please consider calling `.compute_all()` and finding the best value manually."
)
if return_step:
return None, None
return None

if isinstance(self._base_metric, Metric):
fn = torch.max if self.maximize else torch.min
try:
value, idx = fn(self.compute_all(), 0) # type: ignore[arg-type]
value, idx = fn(res, 0)
if return_step:
return value.item(), idx.item()
return value.item()
Expand All @@ -214,10 +235,9 @@ def best_metric(
return None

else: # this is a metric collection
res = self.compute_all()
maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize]
value, idx = {}, {}
for i, (k, v) in enumerate(res.items()): # type: ignore[union-attr]
for i, (k, v) in enumerate(res.items()):
try:
fn = torch.max if maximize[i] else torch.min
out = fn(v, 0)
Expand All @@ -238,4 +258,4 @@ def best_metric(
def _check_for_increment(self, method: str) -> None:
    """Check that a metric that can be updated/used for computations has been initialized."""
if not self._increment_called:
raise ValueError(f"`{method}` cannot be called before `.increment()` has been called")
raise ValueError(f"`{method}` cannot be called before `.increment()` has been called.")
32 changes: 31 additions & 1 deletion tests/unittests/wrappers/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import torch

from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection, MultioutputWrapper
from torchmetrics.classification import (
MulticlassAccuracy,
MulticlassConfusionMatrix,
Expand Down Expand Up @@ -184,3 +184,33 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric):
else:
assert best is None
assert idx is None


@pytest.mark.parametrize(
    ("input_to_tracker", "assert_type"),
    [
        (MultioutputWrapper(MeanSquaredError(), num_outputs=2), torch.Tensor),
        (  # nested version
            MetricCollection(
                {
                    "mse": MultioutputWrapper(MeanSquaredError(), num_outputs=2),
                    "mae": MultioutputWrapper(MeanAbsoluteError(), num_outputs=2),
                }
            ),
            list,
        ),
    ],
)
def test_metric_tracker_and_collection_multioutput(input_to_tracker, assert_type):
    """Check that MetricTracker supports wrapper inputs and nested structures."""
    num_increments, num_updates = 5, 5
    tracker = MetricTracker(input_to_tracker)
    for _ in range(num_increments):
        tracker.increment()
        for _ in range(num_updates):
            preds = torch.randn(100, 2)
            target = torch.randn(100, 2)
            tracker.update(preds, target)

    # a plain wrapper stacks into a tensor; a nested structure falls back to a list
    all_res = tracker.compute_all()
    assert isinstance(all_res, assert_type)

    # best_metric cannot be auto-resolved for wrapped/nested metrics and returns None
    best_metric, which_epoch = tracker.best_metric(return_step=True)
    assert best_metric is None
    assert which_epoch is None

0 comments on commit 3f74920

Please sign in to comment.