[New metric]: Running (#1752)
* implementation

* docs

* tests

* changelog

* fix skipping doctests

* fix unittest

* Apply suggestions from code review

Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Daniel Stancl <[email protected]>

* fix unittests

* add common runners

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Daniel Stancl <[email protected]>
5 people authored May 10, 2023
1 parent 41bb1fa commit 47c6d1c
Showing 12 changed files with 505 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `Running` wrapper for calculating running statistics ([#1752](https://github.com/Lightning-AI/torchmetrics/pull/1752))


- Added `RelativeAverageSpectralError` and `RootMeanSquaredErrorUsingSlidingWindow` to image package ([#816](https://github.com/PyTorchLightning/metrics/pull/816))


17 changes: 17 additions & 0 deletions docs/source/aggregation/running_mean.rst
@@ -0,0 +1,17 @@
.. customcarditem::
:header: Running Mean
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg
:tags: Aggregation

.. include:: ../links.rst

############
Running Mean
############

Module Interface
________________

.. autoclass:: torchmetrics.RunningMean
:noindex:
:exclude-members: update, compute
17 changes: 17 additions & 0 deletions docs/source/aggregation/running_sum.rst
@@ -0,0 +1,17 @@
.. customcarditem::
:header: Running Sum
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg
:tags: Aggregation

.. include:: ../links.rst

###########
Running Sum
###########

Module Interface
________________

.. autoclass:: torchmetrics.RunningSum
:noindex:
:exclude-members: update, compute
17 changes: 17 additions & 0 deletions docs/source/wrappers/running.rst
@@ -0,0 +1,17 @@
.. customcarditem::
:header: Running
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/graph_classification.svg
:tags: Wrappers

.. include:: ../links.rst

#######
Running
#######

Module Interface
________________

.. autoclass:: torchmetrics.wrappers.Running
:noindex:
:exclude-members: update, compute
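
The generic ``Running`` wrapper documented above re-evaluates its base metric over only the most recent ``window`` updates. A minimal sketch of wrapping an arbitrary aggregation metric, assuming the ``Running(base_metric, window=...)`` signature used by the ``RunningMean``/``RunningSum`` subclasses later in this diff:

```python
# Sketch only: assumes the Running API introduced in this PR
# (base_metric as first argument, window size as keyword).
from torch import tensor
from torchmetrics.aggregation import SumMetric
from torchmetrics.wrappers import Running

metric = Running(SumMetric(), window=2)  # keep only the two most recent updates
for i in range(4):
    metric.update(tensor(float(i)))      # updates: 0., 1., 2., 3.
print(metric.compute())                  # tensor(5.) -> sum of the last two updates (2 + 3)
```
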
10 changes: 9 additions & 1 deletion src/torchmetrics/__init__.py
@@ -12,7 +12,15 @@
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.aggregation import ( # noqa: E402
CatMetric,
MaxMetric,
MeanMetric,
MinMetric,
RunningMean,
RunningSum,
SumMetric,
)
from torchmetrics.audio._deprecated import _PermutationInvariantTraining as PermutationInvariantTraining # noqa: E402
from torchmetrics.audio._deprecated import ( # noqa: E402
_ScaleInvariantSignalDistortionRatio as ScaleInvariantSignalDistortionRatio,
113 changes: 113 additions & 0 deletions src/torchmetrics/aggregation.py
@@ -21,6 +21,7 @@
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.running import Running

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot", "MaxMetric.plot", "MinMetric.plot"]
@@ -568,3 +569,115 @@ def plot(
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class RunningMean(Running):
"""Aggregate a stream of values into their mean over a running window.

Using this metric compared to `MeanMetric` allows for calculating metrics over a running window of values, instead
of the whole history of values. This is beneficial when you want a better estimate of the metric during
training and don't want to wait for the whole training to finish to get epoch-level estimates.

As input to ``forward`` and ``update`` the metric accepts the following input

- ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
arbitrary shape ``(...,)``.

As output of `forward` and `compute` the metric returns the following output

- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the aggregated mean over all inputs received

Args:
window: The size of the running window.
nan_strategy: options:
- ``'error'``: if any `nan` values are encountered, a RuntimeError is raised
- ``'warn'``: if any `nan` values are encountered, a warning is given and computation continues
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided, any `nan` values are imputed with this value
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float.

Example:
>>> from torch import tensor
>>> from torchmetrics.aggregation import RunningMean
>>> metric = RunningMean(window=3)
>>> for i in range(6):
... current_val = metric(tensor([i]))
... running_val = metric.compute()
... total_val = tensor(sum(list(range(i+1)))) / (i+1) # total mean over all samples
... print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.)
current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000)
current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.)
current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000)
current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.)
current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000)
"""

def __init__(
self,
window: int = 5,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
) -> None:
super().__init__(base_metric=MeanMetric(nan_strategy=nan_strategy, **kwargs), window=window)


class RunningSum(Running):
"""Aggregate a stream of values into their sum over a running window.

Using this metric compared to `SumMetric` allows for calculating metrics over a running window of values, instead
of the whole history of values. This is beneficial when you want a better estimate of the metric during
training and don't want to wait for the whole training to finish to get epoch-level estimates.

As input to ``forward`` and ``update`` the metric accepts the following input

- ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or a tensor of float values with
arbitrary shape ``(...,)``.

As output of `forward` and `compute` the metric returns the following output

- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with the aggregated sum over all inputs received

Args:
window: The size of the running window.
nan_strategy: options:
- ``'error'``: if any `nan` values are encountered, a RuntimeError is raised
- ``'warn'``: if any `nan` values are encountered, a warning is given and computation continues
- ``'ignore'``: all `nan` values are silently removed
- a float: if a float is provided, any `nan` values are imputed with this value
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float.

Example:
>>> from torch import tensor
>>> from torchmetrics.aggregation import RunningSum
>>> metric = RunningSum(window=3)
>>> for i in range(6):
... current_val = metric(tensor([i]))
... running_val = metric.compute()
... total_val = tensor(sum(list(range(i+1)))) # total sum over all samples
... print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)
"""

def __init__(
self,
window: int = 5,
nan_strategy: Union[str, float] = "warn",
**kwargs: Any,
) -> None:
super().__init__(base_metric=SumMetric(nan_strategy=nan_strategy, **kwargs), window=window)
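
Both classes above are thin subclasses that simply pass a pre-configured base metric into ``Running``, so constructing the wrapper directly should be equivalent (a sketch based on the ``__init__`` bodies shown in this diff):

```python
from torchmetrics.aggregation import MeanMetric
from torchmetrics.wrappers import Running

# Equivalent to RunningMean(window=3, nan_strategy="warn"), per the __init__ above.
running_mean = Running(MeanMetric(nan_strategy="warn"), window=3)
```
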
1 change: 1 addition & 0 deletions src/torchmetrics/wrappers/__init__.py
@@ -15,6 +15,7 @@
from torchmetrics.wrappers.classwise import ClasswiseWrapper
from torchmetrics.wrappers.minmax import MinMaxMetric
from torchmetrics.wrappers.multioutput import MultioutputWrapper
from torchmetrics.wrappers.running import Running
from torchmetrics.wrappers.tracker import MetricTracker

__all__ = [
2 changes: 1 addition & 1 deletion src/torchmetrics/wrappers/classwise.py
@@ -15,7 +15,7 @@

from torch import Tensor

from torchmetrics import Metric
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

2 changes: 1 addition & 1 deletion src/torchmetrics/wrappers/multioutput.py
@@ -5,7 +5,7 @@
from torch import Tensor
from torch.nn import ModuleList

from torchmetrics import Metric
from torchmetrics.metric import Metric
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE