
metric tracker #238

Merged: 36 commits into master from feature/180_timeseries, Aug 4, 2021
Changes from 32 commits
Commits (36)
a05c6c9
added initial code / definition for timeseries class
rajs96 May 7, 2021
e9cab00
added base essential methods, some implementation
rajs96 May 9, 2021
0b61584
something
rajs96 May 9, 2021
e6d18b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2021
0b1fbab
add proper typing import
rajs96 May 9, 2021
8900fe5
merge
rajs96 May 9, 2021
2863e03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2021
24c8768
move typing imports to top of file
rajs96 May 9, 2021
e6ea61c
alphabetize typing imports
rajs96 May 9, 2021
7e57870
add implementation for best metric
rajs96 May 10, 2021
895d480
short
Borda May 10, 2021
2f239e5
Merge branch 'master' into feature/180_timeseries
Borda Jun 10, 2021
3d592c5
Merge branch 'master' into feature/180_timeseries
SkafteNicki Jul 8, 2021
be24597
changes
SkafteNicki Jul 12, 2021
d3111bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
7a7b98e
Merge branch 'master' into feature/180_timeseries
SkafteNicki Jul 12, 2021
878a972
changelog
SkafteNicki Jul 12, 2021
6c0674e
Merge branch 'master' into feature/180_timeseries
SkafteNicki Jul 22, 2021
5d6fc9e
Merge branch 'master' into feature/180_timeseries
Borda Jul 26, 2021
90b6255
Merge branch 'master' into feature/180_timeseries
SkafteNicki Jul 29, 2021
7d9d2a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2021
968f83e
fix flake8
SkafteNicki Jul 29, 2021
39dd76f
add typing
SkafteNicki Jul 29, 2021
3cafa7e
Merge branch 'master' into feature/180_timeseries
SkafteNicki Jul 29, 2021
d13fe08
Merge branch 'master' into feature/180_timeseries
SkafteNicki Aug 2, 2021
0bfc7d2
Merge branch 'master' into feature/180_timeseries
SkafteNicki Aug 3, 2021
f4f04af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2021
087d350
Merge branch 'master' into feature/180_timeseries
Borda Aug 4, 2021
69fbf4d
Apply suggestions from code review
Borda Aug 4, 2021
40d4ba4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2021
357daf5
Apply suggestions from code review
Borda Aug 4, 2021
3f90e90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2021
85d92e8
Apply suggestions from code review
Borda Aug 4, 2021
7b92cf7
Apply suggestions from code review
Borda Aug 4, 2021
d52f563
docs
Borda Aug 4, 2021
f4a5205
Merge branch 'master' into feature/180_timeseries
mergify[bot] Aug 4, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- Added support for negative targets in `nDCG` metric ([#378](https://github.com/PyTorchLightning/metrics/pull/378))


- Added `MetricTracker` wrapper metric for keeping track of the same metric over multiple epochs ([#238](https://github.com/PyTorchLightning/metrics/pull/238))


### Changed

- Moved `psnr` and `ssim` from `functional.regression.*` to `functional.image.*` ([#382](https://github.com/PyTorchLightning/metrics/pull/382))
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -556,3 +556,9 @@ BootStrapper

.. autoclass:: torchmetrics.BootStrapper
:noindex:

MetricTracker
~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MetricTracker
:noindex:
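
As context for the docs entry above: a minimal, self-contained usage sketch of the new wrapper (illustrative only; it mirrors the docstring example in torchmetrics/wrappers/tracker.py further down in this diff):

import torch
from torchmetrics import Accuracy, MetricTracker

tracker = MetricTracker(Accuracy(num_classes=10), maximize=True)
for epoch in range(5):
    tracker.increment()  # start a fresh copy of the base metric for this epoch
    for _ in range(5):
        preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
        tracker.update(preds, target)
    print(f"epoch {epoch}: acc={tracker.compute()}")

best_acc, which_epoch = tracker.best_metric(return_step=True)  # best value and its step
all_values = tracker.compute_all()  # tensor of shape (n_steps,)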
76 changes: 76 additions & 0 deletions tests/wrappers/test_tracker.py
@@ -0,0 +1,76 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import pytest
import torch

from tests.helpers import seed_all
from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, Precision, Recall
from torchmetrics.wrappers import MetricTracker

seed_all(42)


def test_raises_error_on_wrong_input():
    with pytest.raises(TypeError, match="metric arg needs to be an instance of a torchmetrics metric .*"):
MetricTracker([1, 2, 3])


@pytest.mark.parametrize(
"method, method_input",
[
("update", (torch.randint(10, (50,)), torch.randint(10, (50,)))),
("forward", (torch.randint(10, (50,)), torch.randint(10, (50,)))),
("compute", None),
],
)
def test_raises_error_if_increment_not_called(method, method_input):
with pytest.raises(ValueError, match=f"`{method}` cannot be called before .*"):
tracker = MetricTracker(Accuracy(num_classes=10))
if method_input is not None:
getattr(tracker, method)(*method_input)
else:
getattr(tracker, method)()


@pytest.mark.parametrize(
"base_metric, metric_input, maximize",
[
(partial(Accuracy, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(partial(Precision, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(partial(Recall, num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
(MeanSquaredError, (torch.randn(50), torch.randn(50)), False),
(MeanAbsoluteError, (torch.randn(50), torch.randn(50)), False),
],
)
def test_tracker(base_metric, metric_input, maximize):
tracker = MetricTracker(base_metric(), maximize=maximize)
for i in range(5):
tracker.increment()
# check both update and forward works
for _ in range(5):
tracker.update(*metric_input)
for _ in range(5):
tracker(*metric_input)

val = tracker.compute()
assert val != 0.0
assert tracker.n_steps == i + 1

assert tracker.n_steps == 5
assert tracker.compute_all().shape[0] == 5
val, idx = tracker.best_metric(return_step=True)
assert val != 0.0
assert idx in list(range(5))
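
One behavior the parametrized test above does not pin down is the exact return convention of best_metric. A possible follow-up test, hypothetical and not part of this PR, assuming best_metric returns (value, step) as implemented in tracker.py below (it reuses the imports at the top of this test file):

def test_best_metric_minimize():
    """Hypothetical extra check: with maximize=False the tracked minimum wins."""
    tracker = MetricTracker(MeanSquaredError(), maximize=False)
    for _ in range(3):
        tracker.increment()
        tracker.update(torch.randn(50), torch.randn(50))
    all_vals = tracker.compute_all()
    best_val, best_step = tracker.best_metric(return_step=True)
    assert best_val == all_vals.min().item()
    assert best_step == int(all_vals.argmin())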
3 changes: 2 additions & 1 deletion torchmetrics/__init__.py
@@ -63,7 +63,7 @@
RetrievalRecall,
)
from torchmetrics.text import WER, BLEUScore, ROUGEScore # noqa: E402
from torchmetrics.wrappers import BootStrapper # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402

__all__ = [
"functional",
@@ -98,6 +98,7 @@
"MeanSquaredLogError",
"Metric",
"MetricCollection",
"MetricTracker",
"PearsonCorrcoef",
"PIT",
"Precision",
1 change: 1 addition & 0 deletions torchmetrics/wrappers/__init__.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401
119 changes: 119 additions & 0 deletions torchmetrics/wrappers/tracker.py
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Tuple, Union

import torch
from torch import Tensor, nn

from torchmetrics.metric import Metric


class MetricTracker(nn.ModuleList):
"""A wrapper class that can help keeping track of a metric over time and implement useful methods. The wrapper
implements the standard `update`, `compute`, `reset` methods that just calls corresponding method of the
currently tracked metric. However, the following additional methods are provided:

-``MetricTracker.n_steps``: number of metrics being tracked

-``MetricTracker.increment()``: initialize a new metric for being tracked

-``MetricTracker.compute_all()``: get the metric value for all steps

-``MetricTracker.best_metric()``: returns the best value

Args:
        metric: instance of a torchmetrics metric to keep track of at each timestep.
        maximize: bool indicating if higher metric values are better (``True``) or lower
            values are better (``False``)

Example::
        >>> import torch
        >>> from torchmetrics import Accuracy, MetricTracker
>>> tracker = MetricTracker(Accuracy(num_classes=10))
>>> for epoch in range(5): # doctest: +SKIP
... tracker.increment() # doctest: +SKIP
... for batch_idx in range(5): # doctest: +SKIP
... preds, target = torch.randint(10, (100,)), torch.randint(10, (100,)) # doctest: +SKIP
... tracker.update(preds, target) # doctest: +SKIP
... print(f"current acc={tracker.compute()}") # doctest: +SKIP
>>> best_acc, which_epoch = tracker.best_metric(return_step=True) # doctest: +SKIP
>>> all_values = tracker.compute_all() # doctest: +SKIP
"""

def __init__(self, metric: Metric, maximize: bool = True) -> None:
super().__init__()
if not isinstance(metric, Metric):
            raise TypeError(f"metric arg needs to be an instance of a torchmetrics metric but got {metric}")
self._base_metric = metric
self.maximize = maximize

self._increment_called = False

@property
def n_steps(self) -> int:
"""Returns the number of times the tracker has been incremented."""
return len(self) - 1 # subtract the base metric

def increment(self) -> None:
"""Creates a new instace of the input metric that will be updated next."""
self._increment_called = True
self.append(deepcopy(self._base_metric))

    def forward(self, *args, **kwargs) -> Any:  # type: ignore
"""Calls forward of the current metric being tracked."""
self._check_for_increment("forward")
return self[-1](*args, **kwargs)

def update(self, *args, **kwargs) -> None: # type: ignore
"""Updates the current metric being tracked."""
self._check_for_increment("update")
self[-1].update(*args, **kwargs)

def compute(self) -> Any:
"""Call compute of the current metric being tracked."""
self._check_for_increment("compute")
return self[-1].compute()

def compute_all(self) -> Tensor:
"""Compute the metric value for all tracked metrics."""
self._check_for_increment("compute_all")
        # skip index 0, which holds the base metric used only as a template
        return torch.stack([metric.compute() for i, metric in enumerate(self) if i != 0], dim=0)

def reset(self) -> None:
"""Resets the current metric being tracked."""
self[-1].reset()

def reset_all(self) -> None:
"""Resets all metrics being tracked."""
for metric in self:
metric.reset()

    def best_metric(self, return_step: bool = False) -> Union[float, Tuple[float, int]]:
        """Returns the best metric value out of all tracked (the highest if ``maximize=True``, else the lowest).

        Args:
            return_step: If ``True``, also return the step at which the best metric value occurred.

        Returns:
            The best metric value, and optionally the timestep.
        """
        fn = torch.max if self.maximize else torch.min
        value, idx = fn(self.compute_all(), 0)
        if return_step:
            return value.item(), idx.item()
        return value.item()

def _check_for_increment(self, method: str) -> None:
if not self._increment_called:
raise ValueError(f"`{method}` cannot be called before `.increment()` has been called")
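
A note on the design: MetricTracker subclasses nn.ModuleList, and assigning self._base_metric as a plain attribute still registers it in the module's _modules dict, so it is counted by len(self) and visited first during iteration. That is why n_steps subtracts one and compute_all skips index 0. A minimal sketch of this ModuleList behavior (hypothetical Holder class, standalone from the PR):

from torch import nn

class Holder(nn.ModuleList):
    def __init__(self, base: nn.Module) -> None:
        super().__init__()
        # Plain attribute assignment registers `base` in self._modules,
        # so it is included in len(self) and in iteration order.
        self._base = base

holder = Holder(nn.Linear(2, 2))
print(len(holder))                 # 1 -- the stored template already counts
holder.append(nn.Linear(2, 2))
print(len(holder))                 # 2 -- appended copy plus the template
print(holder[-1] is holder._base)  # False -- indexing reaches the appended copy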