Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ConfusionMatrix and StatScores for num_classes > 16 #1521

Merged
merged 12 commits into from
Feb 23, 2023

Conversation

vincentvaroquauxads
Copy link
Contributor

@vincentvaroquauxads vincentvaroquauxads commented Feb 17, 2023

e.g. if preds or target is uint8 and num_classes > 16, unique_mapping overflows

What does this PR do?

Fixes #1510

I would suggest to define :
unittests.helpers.NUM_CLASSES=17
and add something like the following to unittests.classification.inputs._multiclass_cases

    pytest.param(
        Input(
            preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1),
            target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.uint8),
        ),
        id="input[single dim uint8-logits]",
    ),

But it breaks a lot of tests...

e.g. if preds or target is uint8 and num_classes > 16, unique_mapping
overflows
@justusschock
Copy link
Member

@vincentvaroquauxads thanks for checking it out and submitting this PR. Would you be interested in enabling the change with num_classes=17 and fixing the issues arising from this? It should almost always be dtype issues.

@codecov
Copy link

codecov bot commented Feb 17, 2023

Codecov Report

Merging #1521 (0803189) into master (a9def1c) will decrease coverage by 36%.
The diff coverage is 100%.

Additional details and impacted files
@@           Coverage Diff            @@
##           master   #1521     +/-   ##
========================================
- Coverage      87%     51%    -36%     
========================================
  Files         216     216             
  Lines       11334   11334             
========================================
- Hits         9845    5752   -4093     
- Misses       1489    5582   +4093     

vincentvaroquauxads and others added 2 commits February 17, 2023 17:51
…le dim int8-logits"

NUM_CLASSES=17, add _multiclass_cases =
    pytest.param(
        Input(
preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE,
NUM_CLASSES), -1),
target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES,
BATCH_SIZE), dtype=torch.int8),
        ),
        id="input[single dim int8-logits]",
    )
@vincentvaroquauxads
Copy link
Contributor Author

I investigated a little bit more and some tests fail because ignore_index=-1 and target.dtype=uint8.
I changed target.dtype to torch.int8, but I still have issues...

Actually, after hours, I finally realized that, using the original master branch, tests are broken in my environment.
It might be linked with this warning:
sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in labels with no true samples. Use zero_division parameter to control this behavior. (scikit-learn 1.0.2)

Sorry, no time to investigate further... And it's another issue (no link with this PR)...

I pushed the unit tests, if anyone can try them on a clean environment.

@Borda Borda changed the title fix: ConfusionMatrix&StatScores for num_classes > 16 fix ConfusionMatrix and StatScores for num_classes > 16 Feb 20, 2023
@Borda Borda changed the title fix ConfusionMatrix and StatScores for num_classes > 16 fix ConfusionMatrix and StatScores for num_classes > 16 Feb 20, 2023
@SkafteNicki SkafteNicki added the bug / fix Something isn't working label Feb 23, 2023
@SkafteNicki SkafteNicki added this to the v0.12 milestone Feb 23, 2023
@mergify mergify bot added the ready label Feb 23, 2023
@SkafteNicki SkafteNicki merged commit 6bc249d into Lightning-AI:master Feb 23, 2023
Borda pushed a commit that referenced this pull request Feb 27, 2023
* fix: ConfusionMatrix&StatScores for num_classes > 16

e.g. if preds or target is uint8 and num_classes > 16, unique_mapping
overflows

* unittest #1521, NUM_CLASSES=17, add multiclass case "single dim int8-logits"

* revert tests

* add byte testing

* changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 6bc249d)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug / fix Something isn't working ready
Projects
None yet
Development

Successfully merging this pull request may close these issues.

MulticlassConfusionMatrix overflows for num_classes > 16
4 participants