Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zero_division option to the precision, recall, f1, fbeta. #2198

Merged
merged 40 commits into from
May 3, 2024

Conversation

i-aki-y
Copy link
Contributor

@i-aki-y i-aki-y commented Nov 2, 2023

What does this PR do?

I want to add zero_division option to the precision, recall, f1, fbeta metrics as well as the sklearn counterparts.
cf, https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn-metrics-precision-score

The zero_division is important when we use samplewise metrics (multidim_average="samplewise") where some samples have no positive targets.

The following example shows that the preds1 and preds2 have the same f1-scores (0.3333).
But the pred2 perfectly matches the target while the pred1 does not.
This means that we could not distinguish the two models: The model can correctly predict no positive sample from the model that returns many false positives.
We can fix this by setting the zero_division=1.
For this example, the preds2 would become (f1=1.0) while the preds1 is still (f1=0.3333).

import torch
from torchmetrics.functional.classification import f1_score

targets = torch.tensor([
    [0, 0, 0, 0],  # sample1
    [0, 0, 0, 0],  # sample2
    [0, 0, 1, 1],  # sample3
])

preds1 = torch.tensor([
    [0, 1, 1, 1],
    [1, 0, 1, 1],
    [0, 0, 1, 1],
])

preds2 = torch.tensor([
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 1, 1],
])

## default behavior
scores1 = f1_score(preds1, targets, task="binary", multidim_average="samplewise")
print(scores1, scores1.mean())
#=> tensor([0., 0., 1.]) tensor(0.3333)

scores2 = f1_score(preds2, targets, task="binary", multidim_average="samplewise")
print(scores2, scores2.mean())
#=> tensor([0., 0., 1.]) tensor(0.3333)

## If zero_division = 1
scores1 = f1_score(preds1, targets, task="binary", multidim_average="samplewise", zero_division=1)
print(scores1, scores1.mean())
#=> tensor([0., 0., 1.]) tensor(0.3333)

scores2 = f1_score(preds2, targets, task="binary", multidim_average="samplewise", zero_division=1)
print(scores2, scores2.mean())
#=> tensor([1., 1., 1.]) tensor(1.)

Note:
The latest sklearn (ver 1.3) has a bug in f1_score when zero_division=1.
scikit-learn/scikit-learn#27577

So, some test cases that compare the results with the sklearn will fail until the bug is fixed.

Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2198.org.readthedocs.build/en/2198/

@SkafteNicki
Copy link
Member

Hi @i-aki-y,
I am okay with this change but would I think it would make sense base StatScores class and then use Precision, Recall, FScore etc.

@i-aki-y
Copy link
Contributor Author

i-aki-y commented Nov 17, 2023

@SkafteNicki Thank you for your comment.

I tried refactoring StatScores to have zero_division.
And removed __init__ and *_precision_recall_score_arg_validation methods, which I added, from the Precision and Recall.

I avoid declaring zero_division argument in StatScores' constructor because StatScores (and some other sub-classes) does not use zero_division.
So, the zero_division argument is passed through the kwargs and is poped before the supre().__init__(**kwargs) code.

ex.

class BinaryStatScores(_AbstractStatScores):
        ...
    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        zero_division = kwargs.pop("zero_division", 0)
        super(_AbstractStatScores, self).__init__(**kwargs)
        if validate_args:
            _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)

@Borda
Copy link
Member

Borda commented Dec 18, 2023

@i-aki-y seems that the results are different then expected...

@i-aki-y
Copy link
Contributor Author

i-aki-y commented Dec 19, 2023

@Borda Thanks

I found some bugs and fixed them.

Since the PR (scikit-learn/scikit-learn#27577) was merged, I confirmed that the fix passes related test cases (ex. pytest tests/unittests/classification) using the dev version of sklearn==1.4.dev0 and pytorch==2.0.1 in my local machine.

@Borda
Copy link
Member

Borda commented Jan 9, 2024

@i-aki-y could you pls have a look at all the failing cases, some with the wrong value...
turning it to draft till the tests are resolved, pls make it ready when test are mostly green :)

@Borda Borda marked this pull request as draft January 9, 2024 12:44
@i-aki-y i-aki-y changed the title Add zero_division option to the precision, recall, f1, fbeta. [WIP] Add zero_division option to the precision, recall, f1, fbeta. Jan 10, 2024
@i-aki-y
Copy link
Contributor Author

i-aki-y commented Jan 10, 2024

@Borda OK, I put the [WIP] in this PR title.

As mentioned above, the current sklearn (1.3.2) has a bug that mishandles zero_division.
I think it will cause the CI errors.
Fortunately, the bugfix has been merged recently, so I expect the next sklearn's version (1.3.3 or 1.4?) will fix these problems.

@mergify mergify bot removed the has conflicts label Mar 16, 2024
@robmarkcole
Copy link

Appears this has gone cold? Keen to see support for zero_division elsewhere too, particularly JaccardIndex

@lantiga
Copy link
Contributor

lantiga commented Mar 29, 2024

@SkafteNicki let's revive this

@i-aki-y
Copy link
Contributor Author

i-aki-y commented Apr 2, 2024

@robmarkcole The jaccard index seems to be implemented by using _safe_devide in the _jaccard_index_reduce, so I think a similar fix is possible.

@SkafteNicki
Copy link
Member

@robmarkcole and @i-aki-y I added support in jaccard index now.
I make sure to get the remaining test passing so we can finally land this PR either today or tomorrow and then have a release afterwards.

@SkafteNicki SkafteNicki added the Priority Critical task/issue label May 2, 2024
@SkafteNicki
Copy link
Member

The PR has been failing for the sensitivity_at_specificity metric for python 3.8 for some time now. Reason being that a bugfix only present in scikit-learn>=1.3.0 for their implementation was needed to match our implementation but this PR makes the required version of scikit-learn<1.3 for python 3.8, so that was the reason those test were failing.
They are now being skipped for old sklearn versions because that is the only solution where we can land this PR.

@mergify mergify bot added the ready label May 2, 2024
@SkafteNicki SkafteNicki merged commit 335ebe6 into Lightning-AI:master May 3, 2024
70 checks passed
baskrahmer pushed a commit to baskrahmer/torchmetrics that referenced this pull request May 13, 2024
…tning-AI#2198)

* Add support of zero_division parameter

* fix overlooked

* Fix type error

* Fix type error

* Fix missing comma

* Doc fix wrong math expression

* Fixed StatScores to have zero_division

* fix missing zero_division arg

* fix device mismatch

* use scikit-learn 1.4.0

* fix scikit-learn min ver

* fix for new sklearn version

* fix scikit-learn requirements

* fix incorrect requirements condition

* fix test code to pass in multiple sklearn versions

* changelog

* better docstring

* add jaccardindex

* fix tests

* skip for old sklearn versions

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants