diff --git a/CHANGELOG.md b/CHANGELOG.md index d31eac47771..0f7ae27b4b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed plotting of multilabel confusion matrix ([#2858](https://github.com/PyTorchLightning/metrics/pull/2858)) + + - Delete `Device2Host` caused by comm with device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840)) diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 4d14349b7f9..d5f8f373c7b 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -270,7 +270,7 @@ def plot_confusion_matrix( fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax) axs = trim_axs(axs, nb) for i in range(nb): - ax = axs[i] if rows != 1 and cols != 1 else axs + ax = axs[i] if (rows != 1 or cols != 1) else axs if fig_label is not None: ax.set_title(f"Label {fig_label[i]}", fontsize=15) im = ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 23777dbdc2a..7d7a5f28cb0 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -393,6 +393,16 @@ def test_multilabel_confusion_matrix_dtype_gpu(self, inputs, dtype): dtype=dtype, ) + @pytest.mark.parametrize("num_labels", [2, NUM_CLASSES]) + def test_multilabel_confusion_matrix_plot(self, num_labels, inputs): + """Test multilabel cm plots.""" + multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=num_labels) + preds = target = torch.ones(1, num_labels).int() + multi_label_confusion_matrix.update(preds, target) + fig, ax = multi_label_confusion_matrix.plot() + assert fig is not None + assert ax is not None + def test_warning_on_nan(): """Test that a warning is given if division by zero happens during normalization of confusion matrix."""