Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve confusion matrix plotting #2358

Merged
merged 4 commits into from
Feb 12, 2024
Merged

Conversation

JonasVerbickas
Copy link
Contributor

@JonasVerbickas JonasVerbickas commented Feb 6, 2024

What does this PR do?

Current confusion matrix plotting tends to produce unusable results.
When I tried to create a normalized multilabel confusion matrix plot this cluttered mess was produced:
image

  1. When normalizing converting tensor to float using val.item() produces numbers with too many decimal places.
  2. Redundant "True class" and "Predicted class" use up a lot of space and make the graph unreadable
  3. Longer labels overlap

Round floats to avoid floating point errors leading to UI overflow.

Rounding to two decimal places seems reasonable since it's difficult to fit more digits into multilabel confusion matrices.

ax.text(jj, ii, str(round(val.item(), 2)), ha="center", va="center", fontsize=15)

Rounding the val tensor itself does not work here. The issue is with the item() method converting incorrectly.

Changing this line converts this plot:
image

It might be worth considering using less aggressive rounding for simpler binary confusion matrices since they have room to display more digits without any overlap.

Removing redundant "True class" and "Predicted class"

Code reducing the number of times x and y labels are shown:

if i // cols == rows-1:  # bottom row only
    ax.set_xlabel("Predicted class", fontsize=15)
if i % cols == 0:  # leftmost column only
    ax.set_ylabel("True class", fontsize=15)

Produces a much cleaner plot without sacrificing any information:
image

Reduce overlap

By utilizing constrained_layout=True:

fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)

Reduces overlap between labels horizontally and vertically from {0, 1} ticks:
image

Even though this does not fix the top row, the longest names in this example are extreme.
Ensuring proper middle-row separation is good enough for now.

Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2358.org.readthedocs.build/en/2358/

Round floats to avoid floating point errors leading to UI overflow.
Remove overlapping text in multilabel plots
by reducing redundant `Predicted class` and `True class` labels.
Use `constrained_layout` to prevent some text from being cut off.
@Borda
Copy link
Member

Borda commented Feb 6, 2024

yes, the readability for binary cases is bad, just thinking about making the long labels as multi-line or smaller font? 🤔

@JonasVerbickas
Copy link
Contributor Author

If I have the time, I will look into multiline labels this week.
That said, I think this commit is a good first step towards confusion matrix plot clarity.

Copy link

codecov bot commented Feb 6, 2024

Codecov Report

Merging #2358 (1337bcd) into master (b187bfd) will decrease coverage by 0%.
Report is 1 commits behind head on master.
The diff coverage is 0%.

Additional details and impacted files
@@          Coverage Diff           @@
##           master   #2358   +/-   ##
======================================
- Coverage      69%     69%   -0%     
======================================
  Files         303     303           
  Lines       17058   17060    +2     
======================================
  Hits        11760   11760           
- Misses       5298    5300    +2     

@SkafteNicki SkafteNicki added the enhancement New feature or request label Feb 12, 2024
@SkafteNicki SkafteNicki added this to the v1.3.x milestone Feb 12, 2024
@SkafteNicki
Copy link
Member

@JonasVerbickas thanks for the contribution. We are going to merge this PR such that it become part of the v1.3.1release. Feel free to send a new PR with further improvements to the plotting capabilities of torchmetrics.

@Borda Borda merged commit 71089f0 into Lightning-AI:master Feb 12, 2024
46 of 54 checks passed
Borda pushed a commit that referenced this pull request Feb 12, 2024
Round floats to avoid floating point errors leading to UI overflow.
Remove overlapping text in multilabel plots
by reducing redundant `Predicted class` and `True class` labels.
Use `constrained_layout` to prevent some text from being cut off.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

(cherry picked from commit 71089f0)
@mergify mergify bot added the ready label Feb 12, 2024
Borda pushed a commit that referenced this pull request Feb 12, 2024
Round floats to avoid floating point errors leading to UI overflow.
Remove overlapping text in multilabel plots
by reducing redundant `Predicted class` and `True class` labels.
Use `constrained_layout` to prevent some text from being cut off.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

(cherry picked from commit 71089f0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request ready
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants