Skip to content

Commit

Permalink
Improve confusion matrix plotting (#2358)
Browse files Browse the repository at this point in the history
Round floats to avoid floating point errors leading to UI overflow.
Remove overlapping text in multilabel plots
by reducing redundant `Predicted class` and `True class` labels.
Use `constrained_layout` to prevent some text from being cut off.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
3 people authored Feb 12, 2024
1 parent c94e21a commit 71089f0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))


- Fixed plotting of confusion matrices ([#2358](https://github.com/Lightning-AI/torchmetrics/pull/2358))

---

## [1.3.0] - 2024-01-10
Expand Down
10 changes: 6 additions & 4 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,17 @@ def plot_confusion_matrix(
fig_label = None
labels = labels or np.arange(n_classes).tolist()

fig, axs = plt.subplots(nrows=rows, ncols=cols) if ax is None else (ax.get_figure(), ax)
fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
axs = trim_axs(axs, nb)
for i in range(nb):
ax = axs[i] if rows != 1 and cols != 1 else axs
if fig_label is not None:
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
ax.set_xlabel("Predicted class", fontsize=15)
ax.set_ylabel("True class", fontsize=15)
if i // cols == rows - 1: # bottom row only
ax.set_xlabel("Predicted class", fontsize=15)
if i % cols == 0: # leftmost column only
ax.set_ylabel("True class", fontsize=15)
ax.set_xticks(list(range(n_classes)))
ax.set_yticks(list(range(n_classes)))
ax.set_xticklabels(labels, rotation=45, fontsize=10)
Expand All @@ -259,7 +261,7 @@ def plot_confusion_matrix(
if add_text:
for ii, jj in product(range(n_classes), range(n_classes)):
val = confmat[i, ii, jj] if confmat.ndim == 3 else confmat[ii, jj]
ax.text(jj, ii, str(val.item()), ha="center", va="center", fontsize=15)
ax.text(jj, ii, str(round(val.item(), 2)), ha="center", va="center", fontsize=15)

return fig, axs

Expand Down

0 comments on commit 71089f0

Please sign in to comment.