From cd7d7884cc84a219c7ed1779da291ba9a6bcb6e9 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 14:47:18 +0200 Subject: [PATCH 01/26] update --- tests/bases/test_aggregation.py | 161 ++++++++++++++ tests/bases/test_average.py | 88 -------- torchmetrics/__init__.py | 7 +- torchmetrics/aggregation.py | 379 ++++++++++++++++++++++++++++++++ torchmetrics/average.py | 109 --------- torchmetrics/metric.py | 10 +- torchmetrics/utilities/data.py | 8 + 7 files changed, 561 insertions(+), 201 deletions(-) create mode 100644 tests/bases/test_aggregation.py delete mode 100644 tests/bases/test_average.py create mode 100644 torchmetrics/aggregation.py delete mode 100644 torchmetrics/average.py diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py new file mode 100644 index 00000000000..3af8807911a --- /dev/null +++ b/tests/bases/test_aggregation.py @@ -0,0 +1,161 @@ +from multiprocessing import Value +import numpy as np +import pytest +import torch + +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics.aggregation import MinMetric, MaxMetric, MeanMetric, SumMetric, CatMetric + + +def compare_mean(values, weights): + return np.average(values.numpy(), weights=weights) + + +def compare_sum(values, weights): + return np.sum(values.numpy()) + + +def compare_min(values, weights): + return np.min(values.numpy()) + + +def compare_max(values, weights): + return np.max(values.numpy()) + + +def compare_cat(values, weights): + return np.concatenate(values.numpy()) + + +# wrap all other than mean metric to take an additional argument +# this lets them fit into the testing framework +class WrappedMinMetric(MinMetric): + def update(self, values, weights): + super().update(values) + + +class WrappedMaxMetric(MaxMetric): + def update(self, values, weights): + super().update(values) + + +class WrappedSumMetric(SumMetric): + def update(self, values, weights): + super().update(values) + + +class WrappedCatMetric(CatMetric): + def update(self, values, weights): + super().update(values) + + +@pytest.mark.parametrize( + "values, weights", + [ + (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)), + (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5), + (torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5), + ], +) +@pytest.mark.parametrize( + "metric_class, compare_fn", + [ + (WrappedMinMetric, compare_min), + (WrappedMaxMetric, compare_max), + (WrappedSumMetric, compare_sum), + (WrappedCatMetric, compare_cat), + (MeanMetric, compare_mean), + ] +) +class TestAggregation(MetricTester): + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): + self.run_class_metric_test( + ddp=ddp, + dist_sync_on_step=dist_sync_on_step, + metric_class=metric_class, + sk_metric=compare_fn, + # Abuse of names here + preds=values, + target=weights, + ) + + def test_aggregation_differentiability(self, metric_class, compare_fn, values, weights): + self.run_differentiability_test(preds=values, target=weights, metric_module=metric_class) + + def test_aggregation_half_cpu(self, metric_class, compare_fn, values, weights): + if metric_class in (WrappedMinMetric, WrappedMaxMetric): + pytest.skip("MinMetric and MaxMetric does not support half dtype on cpu") + self.run_precision_test_cpu(preds=values, target=weights, metric_module=metric_class) + 
+ def test_aggregation_half_gpu(self, metric_class, compare_fn, values, weights): + self.run_precision_test_gpu(preds=values, target=weights, metric_module=metric_class) + + +_case1 = float("nan") * torch.ones(5) +_case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) + +@pytest.mark.parametrize("value", [_case1, _case2]) +@pytest.mark.parametrize("nan_strategy", ['error', 'warn']) +@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) +def test_nan_error(value, nan_strategy, metric_class): + metric = metric_class(nan_strategy=nan_strategy) + if nan_strategy == 'error': + with pytest.raises(RuntimeError, match='Encounted `nan` values in tensor'): + metric(value.clone()) + elif nan_strategy == 'warn': + with pytest.warns(UserWarning, match='Encounted `nan` values in tensor'): + metric(value.clone()) + + +@pytest.mark.parametrize("metric_class, nan_strategy, value, expected", + [ + (MinMetric, 'ignore', _case1, torch.tensor(float("inf"))), + (MinMetric, 2.0, _case1, 2.0), + (MinMetric, 'ignore', _case2, 1.0), + (MinMetric, 2.0, _case2, 1.0), + + (MaxMetric, 'ignore', _case1, -torch.tensor(float("inf"))), + (MaxMetric, 2.0, _case1, 2.0), + (MaxMetric, 'ignore', _case2, 5.0), + (MaxMetric, 2.0, _case2, 5.0), + + (SumMetric, 'ignore', _case1, 0.0), + (SumMetric, 2.0, _case1, 10.0), + (SumMetric, 'ignore', _case2, 12.0), + (SumMetric, 2.0, _case2, 14.0), + + (MeanMetric, 'ignore', _case1, torch.tensor([float("nan")])), + (MeanMetric, 2.0, _case1, 2.0), + (MeanMetric, 'ignore', _case2, 3.0), + (MeanMetric, 2.0, _case2, 2.8), + + (CatMetric, 'ignore', _case1, []), + (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])), + (CatMetric, 'ignore', _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])), + (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])), + ] +) +def test_nan_expected(metric_class, nan_strategy, value, expected): + metric = metric_class(nan_strategy=nan_strategy) + metric.update(value.clone()) + out = metric.compute() + assert np.allclose(out, expected, equal_nan=True) + + +@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) +def test_error_on_wrong_nan_strategy(metric_class): + with pytest.raises(ValueError, match='Arg `nan_strategy` should either .*'): + metric_class(nan_strategy=[]) + + +@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to") +@pytest.mark.parametrize( + "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] +) +def test_mean_metric_broadcasting(weights, expected): + values = torch.arange(24).reshape(2, 3, 4) + avg = MeanMetric() + + assert avg(values, weights) == expected diff --git a/tests/bases/test_average.py b/tests/bases/test_average.py deleted file mode 100644 index 9c84caf8ddc..00000000000 --- a/tests/bases/test_average.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import pytest -import torch - -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.average import AverageMeter - - -def average(values, weights): - return np.average(values, weights=weights) - - -def average_ignore_weights(values, weights): - return np.average(values) - - -class DefaultWeightWrapper(AverageMeter): - def update(self, values, weights): - super().update(values) - - -class ScalarWrapper(AverageMeter): - def update(self, values, weights): - # torch.ravel is PyTorch 1.8 only, so use np.ravel instead - values = 
values.cpu().numpy() - weights = weights.cpu().numpy() - for v, w in zip(np.ravel(values), np.ravel(weights)): - super().update(float(v), float(w)) - - -@pytest.mark.parametrize( - "values, weights", - [ - (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)), - (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5), - (torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5), - ], -) -class TestAverageMeter(MetricTester): - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_average_fn(self, ddp, dist_sync_on_step, values, weights): - self.run_class_metric_test( - ddp=ddp, - dist_sync_on_step=dist_sync_on_step, - metric_class=AverageMeter, - sk_metric=average, - # Abuse of names here - preds=values, - target=weights, - ) - - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_average_fn_default(self, ddp, dist_sync_on_step, values, weights): - self.run_class_metric_test( - ddp=ddp, - dist_sync_on_step=dist_sync_on_step, - metric_class=DefaultWeightWrapper, - sk_metric=average_ignore_weights, - # Abuse of names here - preds=values, - target=weights, - ) - - @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_average_fn_scalar(self, ddp, dist_sync_on_step, values, weights): - self.run_class_metric_test( - ddp=ddp, - dist_sync_on_step=dist_sync_on_step, - metric_class=ScalarWrapper, - sk_metric=average, - # Abuse of names here - preds=values, - target=weights, - ) - - -@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to") -@pytest.mark.parametrize( - "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] -) -def test_AverageMeter_broadcasting(weights, expected): - values = torch.arange(24).reshape(2, 3, 4) - avg = AverageMeter() - - assert avg(values, weights) == expected diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 5f7798a56e6..c9629f3f8ef 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -13,7 +13,7 @@ from torchmetrics import functional # noqa: E402 from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: E402 -from torchmetrics.average import AverageMeter # noqa: E402 +from torchmetrics.aggregation import CatMetric, MinMetric, MaxMetric, MeanMetric, SumMetric # noqa: E402 from torchmetrics.classification import ( # noqa: E402 AUC, AUROC, @@ -80,6 +80,7 @@ "BLEUScore", "BootStrapper", "CalibrationError", + "CatMetric", "CohenKappa", "ConfusionMatrix", "CosineSimilarity", @@ -96,13 +97,16 @@ "KLDivergence", "LPIPS", "MatthewsCorrcoef", + "MaxMetric", "MeanAbsoluteError", "MeanAbsolutePercentageError", + "MeanMetric", "MeanSquaredError", "MeanSquaredLogError", "Metric", "MetricCollection", "MetricTracker", + "MinMetric", "PearsonCorrcoef", "PIT", "Precision", @@ -125,6 +129,7 @@ "Specificity", "SSIM", "StatScores", + "SumMetric", "SymmetricMeanAbsolutePercentageError", "WER", ] diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py new file mode 100644 index 00000000000..9b4060dd55c --- /dev/null +++ b/torchmetrics/aggregation.py @@ -0,0 +1,379 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, Union, List + +import torch +from torch import Tensor +import warnings +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat + +class BaseAggregator(Metric): + """ Base class for aggregation metrics + + Args: + fn: string specifying the reduction function + default_value: default tensor value to use for the metric state + nan_strategy: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + """ + value: Tensor + + def __init__( + self, + fn: str, + default_value: Union[Tensor, List], + nan_strategy: Union[str, float] = 'error', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + allowed_nan_strategy = ('error', 'warn', 'ignore') + if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float): + raise ValueError(f'Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}' + f' but got {nan_strategy}.') + + self.nan_strategy = nan_strategy + self.add_state("value", default=default_value, dist_reduce_fx=fn) + + def _cast_and_nan_check_input(self, x: float) -> Tensor: + if not isinstance(x, Tensor): + x = torch.as_tensor(x, dtype=torch.float32, device=self.device) + + nans = torch.isnan(x) + if any(nans.flatten()): + if self.nan_strategy == 'error': + raise RuntimeError('Encounted `nan` values in tensor') + elif self.nan_strategy == 'warn': + warnings.warn('Encounted `nan` values in tensor. 
Will be removed.', UserWarning) + x = x[~nans] + elif self.nan_strategy == 'ignore': + x = x[~nans] + else: + x[nans] = self.nan_strategy + + return x + + def compute(self) -> Tensor: + return self.value + + @property + def is_differentiable(self) -> bool: + return True + + +class MaxMetric(BaseAggregator): + """ Aggregate a stream of value into their maximum value + + Args: + nan_strategy: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + """ + def __init__( + self, + nan_strategy: Union[str, float] = 'warn', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + 'max', + -torch.tensor(float("inf")), + nan_strategy, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn + ) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """ Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + + """ + value = self._cast_and_nan_check_input(value) + if any(value.flatten()): # make sure tensor not empty + self.value = torch.maximum(self.value, torch.max(value)) + + +class MinMetric(BaseAggregator): + """ Aggregate a stream of value into their minimum value + + Args: + nan_strategy: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + """ + def __init__( + self, + nan_strategy: Union[str, float] = 'warn', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + 'min', + torch.tensor(float("inf")), + nan_strategy, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn + ) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """ Update state with data. + + Args: + value: Either a float or tensor containing data. 
Additional tensor + dimensions will be flattened + + """ + value = self._cast_and_nan_check_input(value) + if any(value.flatten()): # make sure tensor not empty + self.value = torch.minimum(self.value, torch.min(value)) + + +class SumMetric(BaseAggregator): + """ Aggregate a stream of value into their sum + + Args: + nan_strategy: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + """ + def __init__( + self, + nan_strategy: Union[str, float] = 'warn', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + 'sum', + torch.zeros(1), + nan_strategy, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn + ) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """ Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + + """ + value = self._cast_and_nan_check_input(value) + self.value += value.sum() + + +class CatMetric(BaseAggregator): + """ Concatenate a stream of values + + Args: + nan_strategy: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + """ + def __init__( + self, + nan_strategy: Union[str, float] = 'warn', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__('cat', [], nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) + + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """ Update state with data. + + Args: + value: Either a float or tensor containing data. 
Additional tensor + dimensions will be flattened + + """ + value = self._cast_and_nan_check_input(value) + if any(value.flatten()): + self.value.append(value) + + def compute(self) -> Tensor: + return dim_zero_cat(self.value) if self.value else self.value + + +class MeanMetric(BaseAggregator): + """ Aggregate a stream of value into their mean value + + Args: + nan_strategy: + - ``'error'``: if any `nan` values are encounted will give a RuntimeError + - ``'warn'``: if any `nan` values are encounted will give a warning and continue + - ``'ignore'``: all `nan` values are silently removed + - a float: if a float is provided will impude any `nan` values with this value + + compute_on_step: + Forward only calls ``update()`` and returns None if this is + set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. + When `None`, DDP will be used to perform the allgather. + """ + def __init__( + self, + nan_strategy: Union[str, float] = 'warn', + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None + ): + super().__init__( + 'sum', + torch.zeros(1), + nan_strategy, + compute_on_step, + dist_sync_on_step, + process_group, + dist_sync_fn + ) + self.add_state("weight", default=torch.zeros(1), dist_reduce_fx="sum") + + def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None: # type: ignore + """ Update state with data. + + Args: + value: Either a float or tensor containing data. Additional tensor + dimensions will be flattened + weight: Either a float or tensor containing weights for calculating + the average. Shape of weight should be able to broadcast with + the shape of `value`. Default to `1.0` corresponding to simple + harmonic average. + + """ + value = self._cast_and_nan_check_input(value) + weight = self._cast_and_nan_check_input(weight) + + # broadcast weight to values shape + if not hasattr(torch, "broadcast_to"): + if weight.shape == (): + weight = torch.ones_like(value) * weight + if weight.shape != value.shape: + raise ValueError("Broadcasting not supported on PyTorch <1.8") + else: + weight = torch.broadcast_to(weight, value.shape) + + self.value += (value * weight).sum() + self.weight += weight.sum() + + def compute(self) -> Tensor: + return self.value / self.weight diff --git a/torchmetrics/average.py b/torchmetrics/average.py deleted file mode 100644 index b602d57bbca..00000000000 --- a/torchmetrics/average.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Callable, Optional, Union - -import torch -from torch import Tensor - -from torchmetrics.metric import Metric - - -class AverageMeter(Metric): - """Computes the average of a stream of values. - - Forward accepts - - ``value`` (float tensor): ``(...)`` - - ``weight`` (float tensor): ``(...)`` - - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is - set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. - default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. - When `None`, DDP will be used to perform the allgather. - - Example:: - >>> from torchmetrics import AverageMeter - >>> avg = AverageMeter() - >>> avg.update(3) - >>> avg.update(1) - >>> avg.compute() - tensor(2.) - - >>> avg = AverageMeter() - >>> values = torch.tensor([1., 2., 3.]) - >>> avg(values) - tensor(2.) - - >>> avg = AverageMeter() - >>> values = torch.tensor([1., 2.]) - >>> weights = torch.tensor([3., 1.]) - >>> avg(values, weights) - tensor(1.2500) - """ - - value: Tensor - weight: Tensor - - def __init__( - self, - compute_on_step: bool = True, - dist_sync_on_step: bool = False, - process_group: Optional[Any] = None, - dist_sync_fn: Callable = None, - ) -> None: - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - self.add_state("value", torch.zeros(()), dist_reduce_fx="sum") - self.add_state("weight", torch.zeros(()), dist_reduce_fx="sum") - - # TODO: need to be strings because Unions are not pickleable in Python 3.6 - def update(self, value: "Union[Tensor, float]", weight: "Union[Tensor, float]" = 1.0) -> None: # type: ignore - """Updates the average with. 
- - Args: - value: A tensor of observations (can also be a scalar value) - weight: The weight of each observation (automatically broadcasted - to fit ``value``) - """ - if not isinstance(value, Tensor): - value = torch.as_tensor(value, dtype=torch.float32, device=self.value.device) - if not isinstance(weight, Tensor): - weight = torch.as_tensor(weight, dtype=torch.float32, device=self.weight.device) - - # braodcast_to only supported on PyTorch 1.8+ - if not hasattr(torch, "broadcast_to"): - if weight.shape == (): - weight = torch.ones_like(value) * weight - if weight.shape != value.shape: - raise ValueError("Broadcasting not supported on PyTorch <1.8") - else: - weight = torch.broadcast_to(weight, value.shape) - - self.value += (value * weight).sum() - self.weight += weight.sum() - - def compute(self) -> Tensor: - return self.value / self.weight diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 698ae799cef..49cf2ee01a5 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -25,7 +25,7 @@ from torch.nn import Module from torchmetrics.utilities import apply_to_collection, rank_zero_warn -from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum +from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum, dim_zero_max, dim_zero_min from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version @@ -123,8 +123,8 @@ def add_state( default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be reset to this value when ``self.reset()`` is called. dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction + If value is ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"`` we will use ``torch.sum``, ``torch.mean``, + ``torch.cat``, ``torch.min`` and ``torch.max``` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter. persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. 
@@ -160,6 +160,10 @@ def add_state( dist_reduce_fx = dim_zero_sum elif dist_reduce_fx == "mean": dist_reduce_fx = dim_zero_mean + elif dist_reduce_fx == "max": + dist_reduce_fx = dim_zero_max + elif dist_reduce_fx == "min": + dist_reduce_fx = dim_zero_min elif dist_reduce_fx == "cat": dist_reduce_fx = dim_zero_cat elif dist_reduce_fx is not None and not callable(dist_reduce_fx): diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 46648352e8f..7ae9f7dfd51 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -37,6 +37,14 @@ def dim_zero_mean(x: Tensor) -> Tensor: return torch.mean(x, dim=0) +def dim_zero_max(x: Tensor) -> Tensor: + return torch.max(x, dim=0) + + +def dim_zero_min(x: Tensor) -> Tensor: + return torch.min(x, dim=0) + + def _flatten(x: Sequence) -> list: return [item for sublist in x for item in sublist] From 32d2c42ed1e245c699b06cee18f215312fb8313f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Sep 2021 12:53:40 +0000 Subject: [PATCH 02/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_aggregation.py | 47 +++++++------- torchmetrics/__init__.py | 4 +- torchmetrics/aggregation.py | 107 +++++++++++++++----------------- torchmetrics/metric.py | 2 +- 4 files changed, 76 insertions(+), 84 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 3af8807911a..72c86d24ce0 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -1,10 +1,11 @@ from multiprocessing import Value + import numpy as np import pytest import torch from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.aggregation import MinMetric, MaxMetric, MeanMetric, SumMetric, CatMetric +from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric def compare_mean(values, weights): @@ -65,7 +66,7 @@ def update(self, values, weights): (WrappedSumMetric, compare_sum), (WrappedCatMetric, compare_cat), (MeanMetric, compare_mean), - ] + ], ) class TestAggregation(MetricTester): @pytest.mark.parametrize("ddp", [False, True]) @@ -96,46 +97,44 @@ def test_aggregation_half_gpu(self, metric_class, compare_fn, values, weights): _case1 = float("nan") * torch.ones(5) _case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) + @pytest.mark.parametrize("value", [_case1, _case2]) -@pytest.mark.parametrize("nan_strategy", ['error', 'warn']) +@pytest.mark.parametrize("nan_strategy", ["error", "warn"]) @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_nan_error(value, nan_strategy, metric_class): metric = metric_class(nan_strategy=nan_strategy) - if nan_strategy == 'error': - with pytest.raises(RuntimeError, match='Encounted `nan` values in tensor'): + if nan_strategy == "error": + with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"): metric(value.clone()) - elif nan_strategy == 'warn': - with pytest.warns(UserWarning, match='Encounted `nan` values in tensor'): + elif nan_strategy == "warn": + with pytest.warns(UserWarning, match="Encounted `nan` values in tensor"): metric(value.clone()) -@pytest.mark.parametrize("metric_class, nan_strategy, value, expected", +@pytest.mark.parametrize( + "metric_class, nan_strategy, value, expected", [ - (MinMetric, 'ignore', _case1, torch.tensor(float("inf"))), + (MinMetric, "ignore", 
_case1, torch.tensor(float("inf"))), (MinMetric, 2.0, _case1, 2.0), - (MinMetric, 'ignore', _case2, 1.0), + (MinMetric, "ignore", _case2, 1.0), (MinMetric, 2.0, _case2, 1.0), - - (MaxMetric, 'ignore', _case1, -torch.tensor(float("inf"))), + (MaxMetric, "ignore", _case1, -torch.tensor(float("inf"))), (MaxMetric, 2.0, _case1, 2.0), - (MaxMetric, 'ignore', _case2, 5.0), + (MaxMetric, "ignore", _case2, 5.0), (MaxMetric, 2.0, _case2, 5.0), - - (SumMetric, 'ignore', _case1, 0.0), + (SumMetric, "ignore", _case1, 0.0), (SumMetric, 2.0, _case1, 10.0), - (SumMetric, 'ignore', _case2, 12.0), + (SumMetric, "ignore", _case2, 12.0), (SumMetric, 2.0, _case2, 14.0), - - (MeanMetric, 'ignore', _case1, torch.tensor([float("nan")])), + (MeanMetric, "ignore", _case1, torch.tensor([float("nan")])), (MeanMetric, 2.0, _case1, 2.0), - (MeanMetric, 'ignore', _case2, 3.0), + (MeanMetric, "ignore", _case2, 3.0), (MeanMetric, 2.0, _case2, 2.8), - - (CatMetric, 'ignore', _case1, []), + (CatMetric, "ignore", _case1, []), (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])), - (CatMetric, 'ignore', _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])), + (CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])), (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])), - ] + ], ) def test_nan_expected(metric_class, nan_strategy, value, expected): metric = metric_class(nan_strategy=nan_strategy) @@ -146,7 +145,7 @@ def test_nan_expected(metric_class, nan_strategy, value, expected): @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_error_on_wrong_nan_strategy(metric_class): - with pytest.raises(ValueError, match='Arg `nan_strategy` should either .*'): + with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"): metric_class(nan_strategy=[]) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index c9629f3f8ef..635a3a21377 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -2,7 +2,7 @@ import logging as __logging import os -from torchmetrics.__about__ import * # noqa: F401, F403 +from torchmetrics.__about__ import * # noqa: F403 _logger = __logging.getLogger("torchmetrics") _logger.addHandler(__logging.StreamHandler()) @@ -12,8 +12,8 @@ _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) from torchmetrics import functional # noqa: E402 +from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 from torchmetrics.audio import PIT, SI_SDR, SI_SNR, SNR # noqa: E402 -from torchmetrics.aggregation import CatMetric, MinMetric, MaxMetric, MeanMetric, SumMetric # noqa: E402 from torchmetrics.classification import ( # noqa: E402 AUC, AUROC, diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 9b4060dd55c..70a1c3259bc 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union, List +import warnings +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor -import warnings + from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat + class BaseAggregator(Metric): - """ Base class for aggregation metrics + """Base class for aggregation metrics. 
Args: fn: string specifying the reduction function @@ -44,17 +46,18 @@ class BaseAggregator(Metric): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. """ + value: Tensor def __init__( self, fn: str, default_value: Union[Tensor, List], - nan_strategy: Union[str, float] = 'error', + nan_strategy: Union[str, float] = "error", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, @@ -62,10 +65,12 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - allowed_nan_strategy = ('error', 'warn', 'ignore') + allowed_nan_strategy = ("error", "warn", "ignore") if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float): - raise ValueError(f'Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}' - f' but got {nan_strategy}.') + raise ValueError( + f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}" + f" but got {nan_strategy}." + ) self.nan_strategy = nan_strategy self.add_state("value", default=default_value, dist_reduce_fx=fn) @@ -76,12 +81,12 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: nans = torch.isnan(x) if any(nans.flatten()): - if self.nan_strategy == 'error': - raise RuntimeError('Encounted `nan` values in tensor') - elif self.nan_strategy == 'warn': - warnings.warn('Encounted `nan` values in tensor. Will be removed.', UserWarning) + if self.nan_strategy == "error": + raise RuntimeError("Encounted `nan` values in tensor") + elif self.nan_strategy == "warn": + warnings.warn("Encounted `nan` values in tensor. Will be removed.", UserWarning) x = x[~nans] - elif self.nan_strategy == 'ignore': + elif self.nan_strategy == "ignore": x = x[~nans] else: x[nans] = self.nan_strategy @@ -97,7 +102,7 @@ def is_differentiable(self) -> bool: class MaxMetric(BaseAggregator): - """ Aggregate a stream of value into their maximum value + """Aggregate a stream of value into their maximum value. Args: nan_strategy: @@ -119,31 +124,31 @@ class MaxMetric(BaseAggregator): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. """ + def __init__( self, - nan_strategy: Union[str, float] = 'warn', + nan_strategy: Union[str, float] = "warn", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable = None, ): super().__init__( - 'max', + "max", -torch.tensor(float("inf")), nan_strategy, compute_on_step, dist_sync_on_step, process_group, - dist_sync_fn + dist_sync_fn, ) def update(self, value: Union[float, Tensor]) -> None: # type: ignore - """ Update state with data. + """Update state with data. Args: value: Either a float or tensor containing data. Additional tensor dimensions will be flattened - """ value = self._cast_and_nan_check_input(value) if any(value.flatten()): # make sure tensor not empty @@ -151,7 +156,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore class MinMetric(BaseAggregator): - """ Aggregate a stream of value into their minimum value + """Aggregate a stream of value into their minimum value. Args: nan_strategy: @@ -173,31 +178,31 @@ class MinMetric(BaseAggregator): Callback that performs the allgather operation on the metric state. 
When `None`, DDP will be used to perform the allgather. """ + def __init__( self, - nan_strategy: Union[str, float] = 'warn', + nan_strategy: Union[str, float] = "warn", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable = None, ): super().__init__( - 'min', + "min", torch.tensor(float("inf")), nan_strategy, compute_on_step, dist_sync_on_step, process_group, - dist_sync_fn + dist_sync_fn, ) def update(self, value: Union[float, Tensor]) -> None: # type: ignore - """ Update state with data. + """Update state with data. Args: value: Either a float or tensor containing data. Additional tensor dimensions will be flattened - """ value = self._cast_and_nan_check_input(value) if any(value.flatten()): # make sure tensor not empty @@ -205,7 +210,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore class SumMetric(BaseAggregator): - """ Aggregate a stream of value into their sum + """Aggregate a stream of value into their sum. Args: nan_strategy: @@ -227,38 +232,32 @@ class SumMetric(BaseAggregator): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. """ + def __init__( self, - nan_strategy: Union[str, float] = 'warn', + nan_strategy: Union[str, float] = "warn", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable = None, ): super().__init__( - 'sum', - torch.zeros(1), - nan_strategy, - compute_on_step, - dist_sync_on_step, - process_group, - dist_sync_fn + "sum", torch.zeros(1), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn ) def update(self, value: Union[float, Tensor]) -> None: # type: ignore - """ Update state with data. + """Update state with data. Args: value: Either a float or tensor containing data. Additional tensor dimensions will be flattened - """ value = self._cast_and_nan_check_input(value) self.value += value.sum() class CatMetric(BaseAggregator): - """ Concatenate a stream of values + """Concatenate a stream of values. Args: nan_strategy: @@ -280,23 +279,23 @@ class CatMetric(BaseAggregator): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. """ + def __init__( self, - nan_strategy: Union[str, float] = 'warn', + nan_strategy: Union[str, float] = "warn", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable = None, ): - super().__init__('cat', [], nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) + super().__init__("cat", [], nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) def update(self, value: Union[float, Tensor]) -> None: # type: ignore - """ Update state with data. + """Update state with data. Args: value: Either a float or tensor containing data. Additional tensor dimensions will be flattened - """ value = self._cast_and_nan_check_input(value) if any(value.flatten()): @@ -307,7 +306,7 @@ def compute(self) -> Tensor: class MeanMetric(BaseAggregator): - """ Aggregate a stream of value into their mean value + """Aggregate a stream of value into their mean value. Args: nan_strategy: @@ -329,27 +328,22 @@ class MeanMetric(BaseAggregator): Callback that performs the allgather operation on the metric state. 
When `None`, DDP will be used to perform the allgather. """ + def __init__( self, - nan_strategy: Union[str, float] = 'warn', + nan_strategy: Union[str, float] = "warn", compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, - dist_sync_fn: Callable = None + dist_sync_fn: Callable = None, ): super().__init__( - 'sum', - torch.zeros(1), - nan_strategy, - compute_on_step, - dist_sync_on_step, - process_group, - dist_sync_fn + "sum", torch.zeros(1), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn ) self.add_state("weight", default=torch.zeros(1), dist_reduce_fx="sum") def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None: # type: ignore - """ Update state with data. + """Update state with data. Args: value: Either a float or tensor containing data. Additional tensor @@ -358,7 +352,6 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0 the average. Shape of weight should be able to broadcast with the shape of `value`. Default to `1.0` corresponding to simple harmonic average. - """ value = self._cast_and_nan_check_input(value) weight = self._cast_and_nan_check_input(weight) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 49cf2ee01a5..ff54eafbd75 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -25,7 +25,7 @@ from torch.nn import Module from torchmetrics.utilities import apply_to_collection, rank_zero_warn -from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum, dim_zero_max, dim_zero_min +from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_max, dim_zero_mean, dim_zero_min, dim_zero_sum from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version From 3f75b108efeb0fdd785d4f8e730a2ecc6bbc8919 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 14:56:38 +0200 Subject: [PATCH 03/26] changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a23a6128fc2..38b77b6c4fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) +- Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) + + ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) @@ -29,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `half`, `double`, `float` will no longer change the dtype of the metric states. 
Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493)) +- Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506)) + ### Deprecated From 17fd6d1e22ff6a203d23ce0effacb5e09ef28742 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 14:59:43 +0200 Subject: [PATCH 04/26] pep8 --- tests/bases/test_aggregation.py | 2 -- torchmetrics/__init__.py | 3 +-- torchmetrics/metric.py | 8 ++++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 72c86d24ce0..889a39be05b 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -1,5 +1,3 @@ -from multiprocessing import Value - import numpy as np import pytest import torch diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 635a3a21377..3efcaf13bfa 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -2,7 +2,7 @@ import logging as __logging import os -from torchmetrics.__about__ import * # noqa: F403 +from torchmetrics.__about__ import * # noqa: F401, F403 _logger = __logging.getLogger("torchmetrics") _logger.addHandler(__logging.StreamHandler()) @@ -71,7 +71,6 @@ "Accuracy", "AUC", "AUROC", - "AverageMeter", "AveragePrecision", "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index ff54eafbd75..72d5d1278c5 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -123,10 +123,10 @@ def add_state( default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be reset to this value when ``self.reset()`` is called. dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"`` we will use ``torch.sum``, ``torch.mean``, - ``torch.cat``, ``torch.min`` and ``torch.max``` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction - only makes sense if the state is a list, and not a tensor. The user can also pass a custom - function in this parameter. + If value is ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"`` we will use ``torch.sum``, + ``torch.mean``, ``torch.cat``, ``torch.min`` and ``torch.max``` respectively, each with argument + ``dim=0``. Note that the ``"cat"`` reduction only makes sense if the state is a list, and not + a tensor. The user can also pass a custom function in this parameter. persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. Default is ``False``. From ce9621f2b02a540c3099959ee3198822175b0275 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 16:14:32 +0200 Subject: [PATCH 05/26] docs --- docs/source/references/modules.rst | 38 +++++++++++++++++++++++++++--- torchmetrics/aggregation.py | 3 +++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index ca891c6afd5..8c89625c14d 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -14,10 +14,42 @@ metrics. .. autoclass:: torchmetrics.Metric :noindex: -We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating -your own metric type might be too burdensome. -.. 
autoclass:: torchmetrics.AverageMeter +************************* +Basic Aggregation Metrics +************************* + +Torchmetrics comes with a number of metrics for aggregation of basic statistics: mean, max, min etc. of +either tensors or native python floats. + +CatMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.CatMetric + :noindex: + +MaxMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.MaxMetric + :noindex: + +MeanMetric +~~~~~~~~~~ + +.. autoclass:: torchmetrics.MeanMetric + :noindex: + +MinMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.MinMetric + :noindex: + +SumMetric +~~~~~~~~~ + +.. autoclass:: torchmetrics.SumMetric :noindex: ************* diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 70a1c3259bc..a307dd031e1 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -94,6 +94,7 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: return x def compute(self) -> Tensor: + """ Compute the aggregated value """ return self.value @property @@ -302,6 +303,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore self.value.append(value) def compute(self) -> Tensor: + """ Compute the aggregated value """ return dim_zero_cat(self.value) if self.value else self.value @@ -369,4 +371,5 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0 self.weight += weight.sum() def compute(self) -> Tensor: + """ Compute the aggregated value """ return self.value / self.weight From 11d67807503fa07d7151462ae39719a3d383573d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Sep 2021 14:15:12 +0000 Subject: [PATCH 06/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/aggregation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index a307dd031e1..f3b4e9a5f2d 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -94,7 +94,7 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: return x def compute(self) -> Tensor: - """ Compute the aggregated value """ + """Compute the aggregated value.""" return self.value @property @@ -303,7 +303,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore self.value.append(value) def compute(self) -> Tensor: - """ Compute the aggregated value """ + """Compute the aggregated value.""" return dim_zero_cat(self.value) if self.value else self.value @@ -371,5 +371,5 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0 self.weight += weight.sum() def compute(self) -> Tensor: - """ Compute the aggregated value """ + """Compute the aggregated value.""" return self.value / self.weight From 235735312801bd4d88149221295507eb7a89062a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Sep 2021 16:25:04 +0200 Subject: [PATCH 07/26] examples --- torchmetrics/aggregation.py | 70 +++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index a307dd031e1..bfe4a1c5dc0 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -45,6 +45,11 @@ class BaseAggregator(Metric): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. 
+ + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + """ value: Tensor @@ -124,6 +129,19 @@ class MaxMetric(BaseAggregator): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import MaxMetric + >>> metric = MaxMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor(3.) + """ def __init__( @@ -178,6 +196,19 @@ class MinMetric(BaseAggregator): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import MinMetric + >>> metric = MinMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor(1.) + """ def __init__( @@ -232,6 +263,19 @@ class SumMetric(BaseAggregator): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import SumMetric + >>> metric = SumMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor([6.]) + """ def __init__( @@ -279,6 +323,19 @@ class CatMetric(BaseAggregator): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import CatMetric + >>> metric = CatMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor([1., 2., 3.]) + """ def __init__( @@ -329,6 +386,19 @@ class MeanMetric(BaseAggregator): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float + + Example: + >>> from torchmetrics import MeanMetric + >>> metric = MeanMetric() + >>> metric.update(1) + >>> metric.update(torch.tensor([2, 3])) + >>> metric.compute() + tensor([2.]) + """ def __init__( From 03c1ddba9b034911db36296eefdfa027097c4aaa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Sep 2021 14:27:10 +0000 Subject: [PATCH 08/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/aggregation.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 374a468878c..6327a0fc12f 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -49,7 +49,6 @@ class BaseAggregator(Metric): Raises: ValueError: If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float - """ value: Tensor @@ -141,7 +140,6 @@ class MaxMetric(BaseAggregator): >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() tensor(3.) 
- """ def __init__( @@ -208,7 +206,6 @@ class MinMetric(BaseAggregator): >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() tensor(1.) - """ def __init__( @@ -275,7 +272,6 @@ class SumMetric(BaseAggregator): >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() tensor([6.]) - """ def __init__( @@ -335,7 +331,6 @@ class CatMetric(BaseAggregator): >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() tensor([1., 2., 3.]) - """ def __init__( @@ -398,7 +393,6 @@ class MeanMetric(BaseAggregator): >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() tensor([2.]) - """ def __init__( From f42dddbc9caf0dafa17d6f20736aebead2f1e8a4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Sep 2021 13:35:49 +0200 Subject: [PATCH 09/26] change max and min --- torchmetrics/aggregation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 374a468878c..db868b707f7 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -171,7 +171,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore """ value = self._cast_and_nan_check_input(value) if any(value.flatten()): # make sure tensor not empty - self.value = torch.maximum(self.value, torch.max(value)) + self.value = torch.max(self.value, torch.max(value)) class MinMetric(BaseAggregator): @@ -238,7 +238,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore """ value = self._cast_and_nan_check_input(value) if any(value.flatten()): # make sure tensor not empty - self.value = torch.minimum(self.value, torch.min(value)) + self.value = torch.min(self.value, torch.min(value)) class SumMetric(BaseAggregator): From 950c91c86ce67b63baa9ecec1cf690cb911b7f7f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Sep 2021 14:24:09 +0200 Subject: [PATCH 10/26] mask gpu testing --- tests/bases/test_aggregation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 889a39be05b..e317ce8125c 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -88,6 +88,7 @@ def test_aggregation_half_cpu(self, metric_class, compare_fn, values, weights): pytest.skip("MinMetric and MaxMetric does not support half dtype on cpu") self.run_precision_test_cpu(preds=values, target=weights, metric_module=metric_class) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_aggregation_half_gpu(self, metric_class, compare_fn, values, weights): self.run_precision_test_gpu(preds=values, target=weights, metric_module=metric_class) From 8cf52a3ed5b3635af4ce1e6a2464ffa436dca81e Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Sep 2021 15:38:45 +0200 Subject: [PATCH 11/26] remove half test --- tests/bases/test_aggregation.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index e317ce8125c..5fb462c122b 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -83,15 +83,6 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va def test_aggregation_differentiability(self, metric_class, compare_fn, values, weights): self.run_differentiability_test(preds=values, target=weights, metric_module=metric_class) - def test_aggregation_half_cpu(self, metric_class, compare_fn, values, weights): - if metric_class in (WrappedMinMetric, WrappedMaxMetric): - 
pytest.skip("MinMetric and MaxMetric does not support half dtype on cpu") - self.run_precision_test_cpu(preds=values, target=weights, metric_module=metric_class) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_aggregation_half_gpu(self, metric_class, compare_fn, values, weights): - self.run_precision_test_gpu(preds=values, target=weights, metric_module=metric_class) - _case1 = float("nan") * torch.ones(5) _case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) From 444727205fb0fce174421bb67783e3f6f2eab751 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Sep 2021 10:47:19 +0200 Subject: [PATCH 12/26] fix tests --- tests/bases/test_aggregation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 5fb462c122b..7dbcca77309 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -23,7 +23,7 @@ def compare_max(values, weights): def compare_cat(values, weights): - return np.concatenate(values.numpy()) + return values.numpy() # wrap all other than mean metric to take an additional argument @@ -75,6 +75,7 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va dist_sync_on_step=dist_sync_on_step, metric_class=metric_class, sk_metric=compare_fn, + check_scriptable=False, # Abuse of names here preds=values, target=weights, From 76067681ff685b46c8b66571fb7198de0ab5f555 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Sep 2021 10:48:38 +0200 Subject: [PATCH 13/26] aggr --- tests/bases/test_aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 7dbcca77309..d2542011a71 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -75,7 +75,7 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va dist_sync_on_step=dist_sync_on_step, metric_class=metric_class, sk_metric=compare_fn, - check_scriptable=False, + check_scriptable=False if metric_class == WrappedCatMetric else True, # Abuse of names here preds=values, target=weights, From 59936ac934bf943b9ca18b59e3d3acfcb07ab06f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Sep 2021 13:09:51 +0200 Subject: [PATCH 14/26] fix --- torchmetrics/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index df49bc4ba7d..31e9724c102 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -95,7 +95,7 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: else: x[nans] = self.nan_strategy - return x + return x.float() def compute(self) -> Tensor: """Compute the aggregated value.""" From 7a44155fcb6b1afbd480bd9f7b04ea029c13ceac Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 30 Sep 2021 11:01:20 +0200 Subject: [PATCH 15/26] fix test --- tests/bases/test_aggregation.py | 13 +++++-------- torchmetrics/aggregation.py | 9 +++++++-- torchmetrics/utilities/data.py | 4 ++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index d2542011a71..c29c3f83845 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -1,8 +1,10 @@ +from functools import partial + import numpy as np import pytest import torch -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from 
tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_PROCESSES, MetricTester from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric @@ -22,10 +24,6 @@ def compare_max(values, weights): return np.max(values.numpy()) -def compare_cat(values, weights): - return values.numpy() - - # wrap all other than mean metric to take an additional argument # this lets them fit into the testing framework class WrappedMinMetric(MinMetric): @@ -62,20 +60,19 @@ def update(self, values, weights): (WrappedMinMetric, compare_min), (WrappedMaxMetric, compare_max), (WrappedSumMetric, compare_sum), - (WrappedCatMetric, compare_cat), (MeanMetric, compare_mean), ], ) class TestAggregation(MetricTester): @pytest.mark.parametrize("ddp", [False, True]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): self.run_class_metric_test( ddp=ddp, dist_sync_on_step=dist_sync_on_step, metric_class=metric_class, sk_metric=compare_fn, - check_scriptable=False if metric_class == WrappedCatMetric else True, + check_scriptable=True, # Abuse of names here preds=values, target=weights, diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 31e9724c102..93c5c5f6fcc 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -99,7 +99,7 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: def compute(self) -> Tensor: """Compute the aggregated value.""" - return self.value + return self.value.squeeze() if isinstance(self.value, Tensor) else self.value @property def is_differentiable(self) -> bool: @@ -356,7 +356,12 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore def compute(self) -> Tensor: """Compute the aggregated value.""" - return dim_zero_cat(self.value) if self.value else self.value + if isinstance(self.value, list) and self.value: + return dim_zero_cat(self.value) + else: + return self.value + #print(self.value) + #return dim_zero_cat(self.value) if self.value else self.value class MeanMetric(BaseAggregator): diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 7ae9f7dfd51..51840795613 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -38,11 +38,11 @@ def dim_zero_mean(x: Tensor) -> Tensor: def dim_zero_max(x: Tensor) -> Tensor: - return torch.max(x, dim=0) + return torch.max(x, dim=0).values def dim_zero_min(x: Tensor) -> Tensor: - return torch.min(x, dim=0) + return torch.min(x, dim=0).values def _flatten(x: Sequence) -> list: From 0eaade7a6e61949e3af9970dacaf745038845ea4 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 30 Sep 2021 11:05:18 +0200 Subject: [PATCH 16/26] fix attribute --- torchmetrics/aggregation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 93c5c5f6fcc..4999af41ffc 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -52,6 +52,7 @@ class BaseAggregator(Metric): """ value: Tensor + is_differentiable = True def __init__( self, @@ -101,10 +102,6 @@ def compute(self) -> Tensor: """Compute the aggregated value.""" return self.value.squeeze() if isinstance(self.value, Tensor) else self.value - @property - def is_differentiable(self) -> bool: - return True - class MaxMetric(BaseAggregator): """Aggregate a stream of value into their maximum value. 
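A note on the `dim_zero_max`/`dim_zero_min` change in the patch above: when called with a `dim` argument, `torch.max` and `torch.min` return a `(values, indices)` named tuple rather than a bare tensor, so the helpers used as `dist_reduce_fx` must take `.values` to keep the metric state tensor-valued during synchronization. A minimal sketch of that behaviour (illustrative only, not part of the patch):

    import torch

    x = torch.tensor([[1.0, 4.0],
                      [3.0, 2.0]])
    out = torch.max(x, dim=0)   # named tuple: torch.return_types.max(values, indices)
    print(out.values)           # tensor([3., 4.]) -- what dim_zero_max should return
    print(out.indices)          # tensor([1, 0])   -- row index of each column maximum
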
From 3cff7f81ecd9acaf21a4bd7e6227d3ed41ed09fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Sep 2021 09:09:57 +0000 Subject: [PATCH 17/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/aggregation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 93c5c5f6fcc..9ad9876fc2f 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -360,8 +360,8 @@ def compute(self) -> Tensor: return dim_zero_cat(self.value) else: return self.value - #print(self.value) - #return dim_zero_cat(self.value) if self.value else self.value + # print(self.value) + # return dim_zero_cat(self.value) if self.value else self.value class MeanMetric(BaseAggregator): From a25a0ec825550fe9b240fd18a04171b1cf7de015 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 30 Sep 2021 11:06:52 +0200 Subject: [PATCH 18/26] remove --- tests/bases/test_aggregation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index c29c3f83845..92d90fcd279 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -1,10 +1,8 @@ -from functools import partial - import numpy as np import pytest import torch -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_PROCESSES, MetricTester +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric From 84ab17f32a47e01cf5298a8aafce19f6c3674406 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 30 Sep 2021 11:23:09 +0200 Subject: [PATCH 19/26] add docstrings --- tests/bases/test_aggregation.py | 14 ++++++++++++++ torchmetrics/aggregation.py | 12 ++++++++---- torchmetrics/utilities/data.py | 9 +++++---- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 92d90fcd279..abda1d432fa 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -7,18 +7,22 @@ def compare_mean(values, weights): + """ reference implementation for mean aggregation """ return np.average(values.numpy(), weights=weights) def compare_sum(values, weights): + """ reference implementation for sum aggregation """ return np.sum(values.numpy()) def compare_min(values, weights): + """ reference implementation for min aggregation """ return np.min(values.numpy()) def compare_max(values, weights): + """ reference implementation for max aggregation """ return np.max(values.numpy()) @@ -26,21 +30,25 @@ def compare_max(values, weights): # this lets them fit into the testing framework class WrappedMinMetric(MinMetric): def update(self, values, weights): + """ only pass values on """ super().update(values) class WrappedMaxMetric(MaxMetric): def update(self, values, weights): + """ only pass values on """ super().update(values) class WrappedSumMetric(SumMetric): def update(self, values, weights): + """ only pass values on """ super().update(values) class WrappedCatMetric(CatMetric): def update(self, values, weights): + """ only pass values on """ super().update(values) @@ -65,6 +73,7 @@ class TestAggregation(MetricTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, 
dist_sync_on_step, metric_class, compare_fn, values, weights): + """ test modular implementation """ self.run_class_metric_test( ddp=ddp, dist_sync_on_step=dist_sync_on_step, @@ -77,6 +86,7 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va ) def test_aggregation_differentiability(self, metric_class, compare_fn, values, weights): + """ test functional implementation """ self.run_differentiability_test(preds=values, target=weights, metric_module=metric_class) @@ -88,6 +98,7 @@ def test_aggregation_differentiability(self, metric_class, compare_fn, values, w @pytest.mark.parametrize("nan_strategy", ["error", "warn"]) @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_nan_error(value, nan_strategy, metric_class): + """ test correct errors are raised """ metric = metric_class(nan_strategy=nan_strategy) if nan_strategy == "error": with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"): @@ -123,6 +134,7 @@ def test_nan_error(value, nan_strategy, metric_class): ], ) def test_nan_expected(metric_class, nan_strategy, value, expected): + """ test that nan values are handled correctly """ metric = metric_class(nan_strategy=nan_strategy) metric.update(value.clone()) out = metric.compute() @@ -131,6 +143,7 @@ def test_nan_expected(metric_class, nan_strategy, value, expected): @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_error_on_wrong_nan_strategy(metric_class): + """ test error raised on wrong nan_strategy argument """ with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"): metric_class(nan_strategy=[]) @@ -140,6 +153,7 @@ def test_error_on_wrong_nan_strategy(metric_class): "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] ) def test_mean_metric_broadcasting(weights, expected): + """ check that weight broadcasting works for mean metric """ values = torch.arange(24).reshape(2, 3, 4) avg = MeanMetric() diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 147077a8e6d..80792386f29 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -80,7 +80,10 @@ def __init__( self.nan_strategy = nan_strategy self.add_state("value", default=default_value, dist_reduce_fx=fn) - def _cast_and_nan_check_input(self, x: float) -> Tensor: + def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor: + """ Converts input x to a tensor if not already and afterwards + checks for nans that either give an error, warning or just ignored + """ if not isinstance(x, Tensor): x = torch.as_tensor(x, dtype=torch.float32, device=self.device) @@ -88,7 +91,7 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: if any(nans.flatten()): if self.nan_strategy == "error": raise RuntimeError("Encounted `nan` values in tensor") - elif self.nan_strategy == "warn": + if self.nan_strategy == "warn": warnings.warn("Encounted `nan` values in tensor. 
Will be removed.", UserWarning) x = x[~nans] elif self.nan_strategy == "ignore": @@ -98,6 +101,9 @@ def _cast_and_nan_check_input(self, x: float) -> Tensor: return x.float() + def update(self, value: Union[float, Tensor]) -> None: # type: ignore + pass + def compute(self) -> Tensor: """Compute the aggregated value.""" return self.value.squeeze() if isinstance(self.value, Tensor) else self.value @@ -357,8 +363,6 @@ def compute(self) -> Tensor: return dim_zero_cat(self.value) else: return self.value - # print(self.value) - # return dim_zero_cat(self.value) if self.value else self.value class MeanMetric(BaseAggregator): diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 51840795613..122eb97fe99 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -22,6 +22,7 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: + """ concatenation along the zero dimension """ x = x if isinstance(x, (list, tuple)) else [x] x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] if not x: # empty list @@ -30,25 +31,25 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: def dim_zero_sum(x: Tensor) -> Tensor: + """ summation along the zero dimension """ return torch.sum(x, dim=0) def dim_zero_mean(x: Tensor) -> Tensor: + """ average along the zero dimension """ return torch.mean(x, dim=0) def dim_zero_max(x: Tensor) -> Tensor: + """ max along the zero dimension """ return torch.max(x, dim=0).values def dim_zero_min(x: Tensor) -> Tensor: + """ min along the zero dimension """ return torch.min(x, dim=0).values -def _flatten(x: Sequence) -> list: - return [item for sublist in x for item in sublist] - - def to_onehot( label_tensor: Tensor, num_classes: Optional[int] = None, From a8d2cce057aa0b3ceaadc52536cd825f8318569a Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 30 Sep 2021 11:29:49 +0200 Subject: [PATCH 20/26] fix mistake --- torchmetrics/utilities/data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 122eb97fe99..3b960c0d991 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -50,6 +50,10 @@ def dim_zero_min(x: Tensor) -> Tensor: return torch.min(x, dim=0).values +def _flatten(x: Sequence) -> list: + return [item for sublist in x for item in sublist] + + def to_onehot( label_tensor: Tensor, num_classes: Optional[int] = None, From eba0aa60e1a6e4787f72a8cc0b1d20dc0954abbd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Sep 2021 09:31:44 +0000 Subject: [PATCH 21/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_aggregation.py | 28 ++++++++++++++-------------- torchmetrics/aggregation.py | 5 ++--- torchmetrics/utilities/data.py | 10 +++++----- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index abda1d432fa..974fd447d96 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -7,22 +7,22 @@ def compare_mean(values, weights): - """ reference implementation for mean aggregation """ + """reference implementation for mean aggregation.""" return np.average(values.numpy(), weights=weights) def compare_sum(values, weights): - """ reference implementation for sum aggregation """ + """reference implementation for sum aggregation.""" return 
np.sum(values.numpy()) def compare_min(values, weights): - """ reference implementation for min aggregation """ + """reference implementation for min aggregation.""" return np.min(values.numpy()) def compare_max(values, weights): - """ reference implementation for max aggregation """ + """reference implementation for max aggregation.""" return np.max(values.numpy()) @@ -30,25 +30,25 @@ def compare_max(values, weights): # this lets them fit into the testing framework class WrappedMinMetric(MinMetric): def update(self, values, weights): - """ only pass values on """ + """only pass values on.""" super().update(values) class WrappedMaxMetric(MaxMetric): def update(self, values, weights): - """ only pass values on """ + """only pass values on.""" super().update(values) class WrappedSumMetric(SumMetric): def update(self, values, weights): - """ only pass values on """ + """only pass values on.""" super().update(values) class WrappedCatMetric(CatMetric): def update(self, values, weights): - """ only pass values on """ + """only pass values on.""" super().update(values) @@ -73,7 +73,7 @@ class TestAggregation(MetricTester): @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): - """ test modular implementation """ + """test modular implementation.""" self.run_class_metric_test( ddp=ddp, dist_sync_on_step=dist_sync_on_step, @@ -86,7 +86,7 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va ) def test_aggregation_differentiability(self, metric_class, compare_fn, values, weights): - """ test functional implementation """ + """test functional implementation.""" self.run_differentiability_test(preds=values, target=weights, metric_module=metric_class) @@ -98,7 +98,7 @@ def test_aggregation_differentiability(self, metric_class, compare_fn, values, w @pytest.mark.parametrize("nan_strategy", ["error", "warn"]) @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_nan_error(value, nan_strategy, metric_class): - """ test correct errors are raised """ + """test correct errors are raised.""" metric = metric_class(nan_strategy=nan_strategy) if nan_strategy == "error": with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"): @@ -134,7 +134,7 @@ def test_nan_error(value, nan_strategy, metric_class): ], ) def test_nan_expected(metric_class, nan_strategy, value, expected): - """ test that nan values are handled correctly """ + """test that nan values are handled correctly.""" metric = metric_class(nan_strategy=nan_strategy) metric.update(value.clone()) out = metric.compute() @@ -143,7 +143,7 @@ def test_nan_expected(metric_class, nan_strategy, value, expected): @pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric]) def test_error_on_wrong_nan_strategy(metric_class): - """ test error raised on wrong nan_strategy argument """ + """test error raised on wrong nan_strategy argument.""" with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"): metric_class(nan_strategy=[]) @@ -153,7 +153,7 @@ def test_error_on_wrong_nan_strategy(metric_class): "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)] ) def test_mean_metric_broadcasting(weights, expected): - """ check that weight broadcasting works for mean metric """ + """check that weight broadcasting works for 
mean metric.""" values = torch.arange(24).reshape(2, 3, 4) avg = MeanMetric() diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 80792386f29..bcfeea2af0d 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -81,9 +81,8 @@ def __init__( self.add_state("value", default=default_value, dist_reduce_fx=fn) def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor: - """ Converts input x to a tensor if not already and afterwards - checks for nans that either give an error, warning or just ignored - """ + """Converts input x to a tensor if not already and afterwards checks for nans that either give an error, + warning or just ignored.""" if not isinstance(x, Tensor): x = torch.as_tensor(x, dtype=torch.float32, device=self.device) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index 122eb97fe99..c6f50ae1ebd 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -22,7 +22,7 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: - """ concatenation along the zero dimension """ + """concatenation along the zero dimension.""" x = x if isinstance(x, (list, tuple)) else [x] x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] if not x: # empty list @@ -31,22 +31,22 @@ def dim_zero_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: def dim_zero_sum(x: Tensor) -> Tensor: - """ summation along the zero dimension """ + """summation along the zero dimension.""" return torch.sum(x, dim=0) def dim_zero_mean(x: Tensor) -> Tensor: - """ average along the zero dimension """ + """average along the zero dimension.""" return torch.mean(x, dim=0) def dim_zero_max(x: Tensor) -> Tensor: - """ max along the zero dimension """ + """max along the zero dimension.""" return torch.max(x, dim=0).values def dim_zero_min(x: Tensor) -> Tensor: - """ min along the zero dimension """ + """min along the zero dimension.""" return torch.min(x, dim=0).values From 0ad2b410ccaf4e49221608bf4dc6e2341a446180 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 30 Sep 2021 12:00:47 +0200 Subject: [PATCH 22/26] Update torchmetrics/aggregation.py --- torchmetrics/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index bcfeea2af0d..fea43258e9c 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -273,7 +273,7 @@ class SumMetric(BaseAggregator): >>> metric.update(1) >>> metric.update(torch.tensor([2, 3])) >>> metric.compute() - tensor([6.]) + tensor(6.) 
""" def __init__( From bc891a975b641ccfdb7573b0e9d0362afb12897f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 7 Oct 2021 09:55:13 +0200 Subject: [PATCH 23/26] suggestions --- torchmetrics/aggregation.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index fea43258e9c..53f27d1260d 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -27,7 +27,7 @@ class BaseAggregator(Metric): Args: fn: string specifying the reduction function default_value: default tensor value to use for the metric state - nan_strategy: + nan_strategy: options: - ``'error'``: if any `nan` values are encounted will give a RuntimeError - ``'warn'``: if any `nan` values are encounted will give a warning and continue - ``'ignore'``: all `nan` values are silently removed @@ -52,11 +52,12 @@ class BaseAggregator(Metric): """ value: Tensor - is_differentiable = True + is_differentiable = None + higher_is_better = None def __init__( self, - fn: str, + fn: Union[Callable, str], default_value: Union[Tensor, List], nan_strategy: Union[str, float] = "error", compute_on_step: bool = True, @@ -112,7 +113,7 @@ class MaxMetric(BaseAggregator): """Aggregate a stream of value into their maximum value. Args: - nan_strategy: + nan_strategy: options: - ``'error'``: if any `nan` values are encounted will give a RuntimeError - ``'warn'``: if any `nan` values are encounted will give a warning and continue - ``'ignore'``: all `nan` values are silently removed @@ -178,7 +179,7 @@ class MinMetric(BaseAggregator): """Aggregate a stream of value into their minimum value. Args: - nan_strategy: + nan_strategy: options: - ``'error'``: if any `nan` values are encounted will give a RuntimeError - ``'warn'``: if any `nan` values are encounted will give a warning and continue - ``'ignore'``: all `nan` values are silently removed @@ -244,7 +245,7 @@ class SumMetric(BaseAggregator): """Aggregate a stream of value into their sum. Args: - nan_strategy: + nan_strategy: options: - ``'error'``: if any `nan` values are encounted will give a RuntimeError - ``'warn'``: if any `nan` values are encounted will give a warning and continue - ``'ignore'``: all `nan` values are silently removed @@ -303,7 +304,7 @@ class CatMetric(BaseAggregator): """Concatenate a stream of values. Args: - nan_strategy: + nan_strategy: options: - ``'error'``: if any `nan` values are encounted will give a RuntimeError - ``'warn'``: if any `nan` values are encounted will give a warning and continue - ``'ignore'``: all `nan` values are silently removed @@ -368,7 +369,7 @@ class MeanMetric(BaseAggregator): """Aggregate a stream of value into their mean value. 
Args: - nan_strategy: + nan_strategy: options: - ``'error'``: if any `nan` values are encounted will give a RuntimeError - ``'warn'``: if any `nan` values are encounted will give a warning and continue - ``'ignore'``: all `nan` values are silently removed From f302e11471b8fdbe65151a613e44fb673d59787a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 8 Oct 2021 10:28:12 +0200 Subject: [PATCH 24/26] diff test --- tests/bases/test_aggregation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 974fd447d96..174afe12514 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -85,10 +85,6 @@ def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, va target=weights, ) - def test_aggregation_differentiability(self, metric_class, compare_fn, values, weights): - """test functional implementation.""" - self.run_differentiability_test(preds=values, target=weights, metric_module=metric_class) - _case1 = float("nan") * torch.ones(5) _case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0]) From 57dccee020f69030950f723682e5b9bc1399da78 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 8 Oct 2021 10:38:09 +0200 Subject: [PATCH 25/26] docs --- tests/bases/test_aggregation.py | 5 +++++ torchmetrics/aggregation.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 174afe12514..072c9652452 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -29,24 +29,28 @@ def compare_max(values, weights): # wrap all other than mean metric to take an additional argument # this lets them fit into the testing framework class WrappedMinMetric(MinMetric): + """ Wrapped min metric""" def update(self, values, weights): """only pass values on.""" super().update(values) class WrappedMaxMetric(MaxMetric): + """ Wrapped max metric""" def update(self, values, weights): """only pass values on.""" super().update(values) class WrappedSumMetric(SumMetric): + """ Wrapped min metric""" def update(self, values, weights): """only pass values on.""" super().update(values) class WrappedCatMetric(CatMetric): + """ Wrapped cat metric""" def update(self, values, weights): """only pass values on.""" super().update(values) @@ -70,6 +74,7 @@ def update(self, values, weights): ], ) class TestAggregation(MetricTester): + """ Test aggregation metrics""" @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 53f27d1260d..d1975d4a261 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -102,6 +102,7 @@ def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor: return x.float() def update(self, value: Union[float, Tensor]) -> None: # type: ignore + """ Overwrite in child class """ pass def compute(self) -> Tensor: @@ -361,8 +362,7 @@ def compute(self) -> Tensor: """Compute the aggregated value.""" if isinstance(self.value, list) and self.value: return dim_zero_cat(self.value) - else: - return self.value + return self.value class MeanMetric(BaseAggregator): From 1c74b95fba9c16e154de8ed188b70c2bfe69a713 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Oct 2021 08:38:59 +0000 
Subject: [PATCH 26/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_aggregation.py | 15 ++++++++++----- torchmetrics/aggregation.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 072c9652452..106621e9cb4 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -29,28 +29,32 @@ def compare_max(values, weights): # wrap all other than mean metric to take an additional argument # this lets them fit into the testing framework class WrappedMinMetric(MinMetric): - """ Wrapped min metric""" + """Wrapped min metric.""" + def update(self, values, weights): """only pass values on.""" super().update(values) class WrappedMaxMetric(MaxMetric): - """ Wrapped max metric""" + """Wrapped max metric.""" + def update(self, values, weights): """only pass values on.""" super().update(values) class WrappedSumMetric(SumMetric): - """ Wrapped min metric""" + """Wrapped min metric.""" + def update(self, values, weights): """only pass values on.""" super().update(values) class WrappedCatMetric(CatMetric): - """ Wrapped cat metric""" + """Wrapped cat metric.""" + def update(self, values, weights): """only pass values on.""" super().update(values) @@ -74,7 +78,8 @@ def update(self, values, weights): ], ) class TestAggregation(MetricTester): - """ Test aggregation metrics""" + """Test aggregation metrics.""" + @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False]) def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights): diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index d1975d4a261..e009abfec2b 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -102,7 +102,7 @@ def _cast_and_nan_check_input(self, x: Union[float, Tensor]) -> Tensor: return x.float() def update(self, value: Union[float, Tensor]) -> None: # type: ignore - """ Overwrite in child class """ + """Overwrite in child class.""" pass def compute(self) -> Tensor:
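
For readers following the series: a short usage sketch of the `nan_strategy` option these patches introduce, shown with `MeanMetric`. This assumes a torchmetrics build that includes the series and is not part of the patch itself:

    import torch
    from torchmetrics import MeanMetric

    metric = MeanMetric(nan_strategy="warn")   # also accepts "error", "ignore", or a float
    metric.update(torch.tensor([1.0, float("nan"), 3.0]))  # warns, then drops the nan
    print(metric.compute())                    # 2.0 -- mean of the remaining values

    metric = MeanMetric(nan_strategy=0.0)      # a float replaces every nan with that value
    metric.update(torch.tensor([1.0, float("nan"), 3.0]))
    print(metric.compute())                    # ~1.3333 -- mean of [1.0, 0.0, 3.0]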