diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ec6ffbeaa3..1377394d14f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366)) +- Fixed plotting of confusion matrices ([#2358](https://github.com/Lightning-AI/torchmetrics/pull/2358)) + + ## [1.3.0] - 2024-01-10 ### Added diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 476be4bd0e8..feb906810bc 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -242,15 +242,17 @@ def plot_confusion_matrix( fig_label = None labels = labels or np.arange(n_classes).tolist() - fig, axs = plt.subplots(nrows=rows, ncols=cols) if ax is None else (ax.get_figure(), ax) + fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax) axs = trim_axs(axs, nb) for i in range(nb): ax = axs[i] if rows != 1 and cols != 1 else axs if fig_label is not None: ax.set_title(f"Label {fig_label[i]}", fontsize=15) ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach()) - ax.set_xlabel("Predicted class", fontsize=15) - ax.set_ylabel("True class", fontsize=15) + if i // cols == rows - 1: # bottom row only + ax.set_xlabel("Predicted class", fontsize=15) + if i % cols == 0: # leftmost column only + ax.set_ylabel("True class", fontsize=15) ax.set_xticks(list(range(n_classes))) ax.set_yticks(list(range(n_classes))) ax.set_xticklabels(labels, rotation=45, fontsize=10) @@ -259,7 +261,7 @@ def plot_confusion_matrix( if add_text: for ii, jj in product(range(n_classes), range(n_classes)): val = confmat[i, ii, jj] if confmat.ndim == 3 else confmat[ii, jj] - ax.text(jj, ii, str(val.item()), ha="center", va="center", fontsize=15) + ax.text(jj, ii, str(round(val.item(), 2)), ha="center", va="center", fontsize=15) return fig, axs