From 4c8e6858bf7713effe2166d0008a92b6611003ab Mon Sep 17 00:00:00 2001 From: asteyo <49311556+asteyo@users.noreply.github.com> Date: Mon, 27 Sep 2021 08:26:15 +0300 Subject: [PATCH] R2 score metric (#1274) * r2_score added * catalyst-make-codestyle _r2_score.py * r2 score LoaderMetric API is added * r2 score renamed to r2 squared * functional r2 metric name fix to r2_squared * test for functional r2 squared is added * compute key-value fix * args order in update fixed * args order fix * r2squared import is added to functional metrics init * r2squared callback is added * r2squared callback is added to metrics callbacks init * r2squared metric is added to metrics init * tests for r2squared is added * regression test update * metrics docs update * codestyle fix * torch.square to torch.pow fix) * codestyle update * spaces codestyle fix * codestyle fix * Update _r2_squared.py Co-authored-by: Sergey Kolesnikov --- catalyst/callbacks/metrics/__init__.py | 2 + catalyst/callbacks/metrics/r2_squared.py | 75 +++++++++++++++++ catalyst/metrics/__init__.py | 1 + catalyst/metrics/_r2_squared.py | 64 ++++++++++++++ catalyst/metrics/functional/__init__.py | 1 + catalyst/metrics/functional/_r2_squared.py | 50 +++++++++++ docs/api/metrics.rst | 14 ++++ .../metrics/functional/test_r2_squared.py | 16 ++++ tests/catalyst/metrics/test_r2squared.py | 83 +++++++++++++++++++ tests/pipelines/test_regression.py | 1 + 10 files changed, 307 insertions(+) create mode 100644 catalyst/callbacks/metrics/r2_squared.py create mode 100644 catalyst/metrics/_r2_squared.py create mode 100644 catalyst/metrics/functional/_r2_squared.py create mode 100644 tests/catalyst/metrics/functional/test_r2_squared.py create mode 100644 tests/catalyst/metrics/test_r2squared.py diff --git a/catalyst/callbacks/metrics/__init__.py b/catalyst/callbacks/metrics/__init__.py index b9331bea67..65a30bc1e3 100644 --- a/catalyst/callbacks/metrics/__init__.py +++ b/catalyst/callbacks/metrics/__init__.py @@ -17,6 +17,8 @@ from 
class R2SquaredCallback(LoaderMetricCallback):
    """R2 Squared (coefficient of determination) metric callback.

    Thin wrapper that plugs an :class:`R2Squared` loader metric into the
    callback machinery of ``LoaderMetricCallback``.

    Args:
        input_key: batch key of the model outputs used for r2squared calculation,
            specifies our ``y_pred`` (``"logits"`` in the example below).
        target_key: batch key of the ground truth used for r2squared calculation,
            specifies our ``y_true`` (``"targets"`` in the example below).
        prefix: metric prefix
        suffix: metric suffix

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # data
        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            num_epochs=8,
            verbose=True,
            callbacks=[
                dl.R2SquaredCallback(input_key="logits", target_key="targets")
            ]
        )

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
    """

    def __init__(
        self,
        input_key: str,
        target_key: str,
        prefix: str = None,
        suffix: str = None,
    ):
        """Init.

        Args:
            input_key: batch key of the model outputs (``y_pred``)
            target_key: batch key of the ground truth (``y_true``)
            prefix: metric prefix, forwarded to ``R2Squared``
            suffix: metric suffix, forwarded to ``R2Squared``
        """
        super().__init__(
            metric=R2Squared(prefix=prefix, suffix=suffix),
            input_key=input_key,
            target_key=target_key,
        )
class R2Squared(ICallbackLoaderMetric):
    """This metric accumulates r2 score (coefficient of determination) along loader.

    The score is ``1 - SS_res / SS_tot``. Batches are not stored; four running
    sums are kept instead, which is sufficient because
    ``SS_tot = sum(y_true ** 2) - sum(y_true) ** 2 / n``.

    Args:
        compute_on_call: if True, allows compute metric's value on call
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> None:
        """Init R2Squared"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.metric_name = f"{self.prefix}r2squared{self.suffix}"
        self._reset_accumulators()

    def _reset_accumulators(self) -> None:
        """Zero the running sums shared by ``__init__`` and ``reset``."""
        self.num_examples = 0  # n: number of target elements seen
        self.delta_sum = 0  # sum((y_pred - y_true) ** 2), i.e. SS_res
        self.y_sum = 0  # sum(y_true)
        self.y_sq_sum = 0  # sum(y_true ** 2)

    def reset(self, num_batches: int, num_samples: int) -> None:
        """
        Reset metrics fields

        Args:
            num_batches: unused by this metric
            num_samples: unused by this metric
        """
        self._reset_accumulators()

    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
        """
        Update accumulated data with new batch

        Args:
            y_pred: model predictions for the batch
            y_true: ground truth for the batch
        """
        # The accumulators live for the whole loader; detaching keeps them from
        # holding on to every batch's autograd graph.
        y_pred = y_pred.detach()
        y_true = y_true.detach()
        # [bs, 1] predictions against [bs] targets (or vice versa) would
        # broadcast to a [bs, bs] matrix and silently corrupt the residual sum.
        if y_pred.shape != y_true.shape and y_pred.numel() == y_true.numel():
            y_pred = y_pred.reshape(y_true.shape)

        # Count elements, not rows: the sums below run over every element, and
        # the variance identity in ``compute`` needs a matching n.
        self.num_examples += y_true.numel()
        self.delta_sum += torch.sum(torch.pow(y_pred - y_true, 2))
        self.y_sum += torch.sum(y_true)
        self.y_sq_sum += torch.sum(torch.pow(y_true, 2))

    def compute(self) -> torch.Tensor:
        """
        Return accumulated metric

        Returns:
            r2 score over everything passed to ``update`` since the last reset

        Raises:
            ZeroDivisionError: if called before any ``update`` (the
                accumulators are still plain integer zeros)
        """
        total_sum_of_squares = self.y_sq_sum - (self.y_sum ** 2) / self.num_examples
        return 1 - self.delta_sum / total_sum_of_squares

    def compute_key_value(self) -> dict:
        """
        Return key-value

        Returns:
            dict with a single entry, ``{metric_name: r2 score}``
        """
        return {self.metric_name: self.compute()}
def r2_squared(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Computes regression r2 squared (coefficient of determination),
    ``1 - SS_res / SS_tot``.

    Both inputs are cast to ``float32`` before the computation. If the two
    tensors hold the same number of elements but differ in shape
    (e.g. ``[bs; 1]`` outputs and ``[bs]`` targets), ``outputs`` is reshaped to
    the shape of ``targets`` so the subtraction stays element-wise instead of
    broadcasting to a ``[bs; bs]`` matrix.

    Args:
        outputs: model outputs
            with shape [bs; 1]
        targets: ground truth
            with shape [bs; 1]

    Returns:
        tensor of shape ``[1]`` with the computed r2 squared. When all targets
        are equal the total sum of squares is zero, so the result is not
        finite (``nan`` or ``-inf``).

    Examples:

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.r2_squared(
            outputs=torch.tensor([0, 1, 2]),
            targets=torch.tensor([0, 1, 2]),
        )
        # tensor([1.])


    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.r2_squared(
            outputs=torch.tensor([2.5, 0.0, 2, 8]),
            targets=torch.tensor([3, -0.5, 2, 7]),
        )
        # tensor([0.9486])
    """
    outputs = outputs.float()
    targets = targets.float()
    # Guard against silent broadcasting of column vs. flat vectors.
    if outputs.shape != targets.shape and outputs.numel() == targets.numel():
        outputs = outputs.reshape(targets.shape)

    total_sum_of_squares = torch.sum(torch.pow(targets - torch.mean(targets), 2)).view(-1)
    residual_sum_of_squares = torch.sum(torch.pow(targets - outputs, 2)).view(-1)
    return 1 - residual_sum_of_squares / total_sum_of_squares
def test_r2_squared():
    """
    Tests for catalyst.metrics.r2_squared metric.
    """
    y_true = torch.tensor([3, -0.5, 2, 7])
    y_pred = torch.tensor([2.5, 0.0, 2, 8])
    val = r2_squared(y_pred, y_true)
    # The exact score is 1 - 1.5 / 29.1875 ~= 0.948608; the reference below is
    # rounded to 4 decimals, so compare with a tolerance that matches that
    # rounding instead of relying on the default rtol=1e-5.
    assert torch.isclose(val, torch.Tensor([0.9486]), atol=1e-4)
@pytest.mark.parametrize(
    "outputs,targets,true_values",
    (
        (
            torch.Tensor([2.5, 0.0, 2, 8]),
            torch.Tensor([3, -0.5, 2, 7]),
            {
                "r2squared": torch.Tensor([0.9486]),
            },
        ),
    ),
)
def test_r2_squared(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    true_values: Dict[str, torch.Tensor],
) -> None:
    """
    Test r2 squared metric

    Args:
        outputs: tensor of outputs
        targets: tensor of targets
        true_values: true metric values
    """
    metric = R2Squared()
    metric.update(y_pred=outputs, y_true=targets)
    metrics = metric.compute_key_value()
    for key, expected in true_values.items():
        # Reference values are rounded to 4 decimals; use a matching tolerance
        # rather than the default rtol=1e-5.
        assert torch.isclose(expected, metrics[key], atol=1e-4)


@pytest.mark.parametrize(
    "outputs_list,targets_list,true_values",
    (
        (
            (
                torch.Tensor([2.5, 0.0, 2, 8]),
                torch.Tensor([2.5, 0.0, 2, 8]),
                torch.Tensor([2.5, 0.0, 2, 8]),
                torch.Tensor([2.5, 0.0, 2, 8]),
            ),
            (
                torch.Tensor([3, -0.5, 2, 7]),
                torch.Tensor([3, -0.5, 2, 7]),
                torch.Tensor([3, -0.5, 2, 7]),
                torch.Tensor([3, -0.5, 2, 7]),
            ),
            {
                "r2squared": torch.Tensor([0.9486]),
            },
        ),
    ),
)
def test_r2_squared_update(
    outputs_list: Iterable[torch.Tensor],
    targets_list: Iterable[torch.Tensor],
    true_values: Dict[str, torch.Tensor],
):
    """
    Test r2 squared metric computation

    Repeating the same batch leaves the score unchanged, so the accumulated
    value must equal the single-batch reference.

    Args:
        outputs_list: list of outputs
        targets_list: list of targets
        true_values: true metric values
    """
    metric = R2Squared()
    for outputs, targets in zip(outputs_list, targets_list):
        metric.update(y_pred=outputs, y_true=targets)
    metrics = metric.compute_key_value()
    for key, expected in true_values.items():
        # Same rounding-aware tolerance as in the single-batch test.
        assert torch.isclose(expected, metrics[key], atol=1e-4)