-
-
Notifications
You must be signed in to change notification settings - Fork 392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
R2 score metric #1274
Merged
Merged
R2 score metric #1274
Changes from 3 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
56c953f
r2_score added
asteyo 767430c
catalyst-make-codestyle _r2_score.py
asteyo faadc56
Merge branch 'master' into r2_score
asteyo 85d8da6
r2 score LoaderMetric API is added
asteyo 622c01e
r2 score renamed to r2 squared
asteyo 29da20f
functional r2 metric name fix to r2_squared
asteyo 62d65f9
test for functional r2 squared is added
asteyo 9991a0c
compute key-value fix
asteyo ea0d905
args order in update fixed
asteyo 3dabd39
args order fix
asteyo c4622ac
r2squared import is added to functional metrics init
asteyo 84b0435
r2squared callback is added
asteyo a4a4de1
r2squared callback is added to metrics callbacks init
asteyo 80be323
r2squared metric is added to metrics init
asteyo 9d03acb
tests for r2squared is added
asteyo eb343f3
regression test update
asteyo 2759cda
metrics docs update
asteyo 2824561
Merge branch 'master' into r2_score
asteyo 5f9713f
codestyle fix
asteyo 932816b
Merge branch 'master' into r2_score
asteyo 7656b3b
Merge branch 'master' of https://github.com/catalyst-team/catalyst in…
asteyo 8abe5ed
torch.square to torch.pow fix)
asteyo f75a2ef
Merge branch 'r2_score' of https://github.com/asteyo/catalyst into r2…
asteyo e7d7623
codestyle update
asteyo 8693c6b
spaces codestyle fix
asteyo 72fb072
codestyle fix
asteyo b60060b
Update _r2_squared.py
Scitator File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Sequence | ||
|
||
import torch | ||
|
||
|
||
def r2_score(outputs: torch.Tensor, targets: torch.Tensor,) -> Sequence[torch.Tensor]: | ||
""" | ||
Computes regression r2 score. | ||
Args: | ||
outputs: model outputs | ||
with shape [bs; 1] | ||
targets: ground truth | ||
with shape [bs; 1] | ||
Returns: | ||
float of computed r2 score | ||
Examples: | ||
.. code-block:: python | ||
import torch | ||
from catalyst import metrics | ||
metrics.r2_score( | ||
outputs=torch.tensor([0, 1, 2]), | ||
targets=torch.tensor([0, 1, 2]), | ||
) | ||
# tensor([1.]) | ||
.. code-block:: python | ||
import torch | ||
from catalyst import metrics | ||
metrics.r2_score( | ||
outputs=torch.tensor([2.5, 0.0, 2, 8]), | ||
targets=torch.tensor([3, -0.5, 2, 7]), | ||
) | ||
# tensor([0.9486]) | ||
""" | ||
total_sum_of_squares = torch.sum( | ||
torch.square(targets.float() - torch.mean(targets.float())) | ||
).view(-1) | ||
residual_sum_of_squares = torch.sum(torch.square(targets.float() - outputs.float())).view(-1) | ||
output = 1 - residual_sum_of_squares / total_sum_of_squares | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe, you should do something like |
||
return output | ||
|
||
|
||
__all__ = ["r2_score"] |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mb it should be done like sklearn implementation https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html ?
I mean, it seems cool to have possibility to calculate r2 score of tensors with shape [bs, num_outputs] right in torch
plus,
sample_weight
might be useful