-
Notifications
You must be signed in to change notification settings - Fork 423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix ConfusionMatrix
and StatScores
for num_classes > 16
#1521
Conversation
e.g. if preds or target is uint8 and num_classes > 16, unique_mapping overflows
@vincentvaroquauxads thanks for checking it out and submitting this PR. Would you be interested in enabling the change with num_classes=17 and fixing the issues arising from this? It should almost always be dtype issues. |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #1521 +/- ##
========================================
- Coverage 87% 51% -36%
========================================
Files 216 216
Lines 11334 11334
========================================
- Hits 9845 5752 -4093
- Misses 1489 5582 +4093 |
…le dim int8-logits" NUM_CLASSES=17, add _multiclass_cases = pytest.param( Input( preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.int8), ), id="input[single dim int8-logits]", )
for more information, see https://pre-commit.ci
I investigated a little bit more and some tests fail because Actually, after hours, I finally realized that, using the original master branch, tests are broken in my environment. Sorry, no time to investigate further... And it's another issue (no link with this PR)... I pushed the unit tests, if anyone can try them on a clean environment. |
ConfusionMatrix
and StatScore
s for num_classes > 16
ConfusionMatrix
and StatScore
s for num_classes > 16ConfusionMatrix
and StatScores
for num_classes > 16
* fix: ConfusionMatrix&StatScores for num_classes > 16 e.g. if preds or target is uint8 and num_classes > 16, unique_mapping overflows * unittest #1521, NUM_CLASSES=17, add multiclass case "single dim int8-logits" * revert tests * add byte testing * changelog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> (cherry picked from commit 6bc249d)
e.g. if preds or target is uint8 and num_classes > 16, unique_mapping overflows
What does this PR do?
Fixes #1510
I would suggest to define :
unittests.helpers.NUM_CLASSES=17
and add something like the following to
unittests.classification.inputs._multiclass_cases
But it breaks a lot of tests...