Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

R2 score metric #1274

Merged
merged 27 commits into from
Sep 27, 2021
Merged
Changes from 3 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
56c953f
r2_score added
asteyo Aug 6, 2021
767430c
catalyst-make-codestyle _r2_score.py
asteyo Aug 9, 2021
faadc56
Merge branch 'master' into r2_score
asteyo Aug 9, 2021
85d8da6
r2 score LoaderMetric API is added
asteyo Aug 30, 2021
622c01e
r2 score renamed to r2 squared
asteyo Sep 11, 2021
29da20f
functional r2 metric name fix to r2_squared
asteyo Sep 11, 2021
62d65f9
test for functional r2 squared is added
asteyo Sep 11, 2021
9991a0c
compute key-value fix
asteyo Sep 20, 2021
ea0d905
args order in update fixed
asteyo Sep 20, 2021
3dabd39
args order fix
asteyo Sep 20, 2021
c4622ac
r2squared import is added to functional metrics init
asteyo Sep 20, 2021
84b0435
r2squared callback is added
asteyo Sep 20, 2021
a4a4de1
r2squared callback is added to metrics callbacks init
asteyo Sep 20, 2021
80be323
r2squared metric is added to metrics init
asteyo Sep 20, 2021
9d03acb
tests for r2squared is added
asteyo Sep 20, 2021
eb343f3
regression test update
asteyo Sep 20, 2021
2759cda
metrics docs update
asteyo Sep 20, 2021
2824561
Merge branch 'master' into r2_score
asteyo Sep 21, 2021
5f9713f
codestyle fix
asteyo Sep 25, 2021
932816b
Merge branch 'master' into r2_score
asteyo Sep 25, 2021
7656b3b
Merge branch 'master' of https://github.com/catalyst-team/catalyst in…
asteyo Sep 25, 2021
8abe5ed
torch.square to torch.pow fix)
asteyo Sep 25, 2021
f75a2ef
Merge branch 'r2_score' of https://github.com/asteyo/catalyst into r2…
asteyo Sep 25, 2021
e7d7623
codestyle update
asteyo Sep 26, 2021
8693c6b
spaces codestyle fix
asteyo Sep 26, 2021
72fb072
codestyle fix
asteyo Sep 26, 2021
b60060b
Update _r2_squared.py
Scitator Sep 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions catalyst/metrics/functional/_r2_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Sequence

import torch


def r2_score(outputs: torch.Tensor, targets: torch.Tensor,) -> Sequence[torch.Tensor]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mb it should be done like sklearn implementation https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html ?
I mean, it seems cool to have possibility to calculate r2 score of tensors with shape [bs, num_outputs] right in torch
plus, sample_weight might be useful

"""
Computes regression r2 score.
Args:
outputs: model outputs
with shape [bs; 1]
targets: ground truth
with shape [bs; 1]
Returns:
float of computed r2 score
Examples:
.. code-block:: python
import torch
from catalyst import metrics
metrics.r2_score(
outputs=torch.tensor([0, 1, 2]),
targets=torch.tensor([0, 1, 2]),
)
# tensor([1.])
.. code-block:: python
import torch
from catalyst import metrics
metrics.r2_score(
outputs=torch.tensor([2.5, 0.0, 2, 8]),
targets=torch.tensor([3, -0.5, 2, 7]),
)
# tensor([0.9486])
"""
total_sum_of_squares = torch.sum(
torch.square(targets.float() - torch.mean(targets.float()))
).view(-1)
residual_sum_of_squares = torch.sum(torch.square(targets.float() - outputs.float())).view(-1)
output = 1 - residual_sum_of_squares / total_sum_of_squares
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe, you should do something like max(total_sum_of_squares, eps) to avoid zero division

return output


__all__ = ["r2_score"]