diff --git a/CHANGELOG.md b/CHANGELOG.md index a8ccaa59b15..bd5d2c13ee5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379)) +- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429)) + + - Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423)) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 06b9b7b4c4e..d6ad1287c58 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -647,12 +647,11 @@ def plot( f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the " f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`" ) - val = val or self.compute() if together: return plot_single_or_multi_val(val, ax=ax) fig_axs = [] - for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)): + for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)): if isinstance(val, dict): f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax) elif isinstance(val, Sequence): diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 736c162ace8..be48c8a3df7 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -855,12 +855,17 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label @pytest.mark.parametrize("together", [True, False]) @pytest.mark.parametrize("num_vals", [1, 2]) -def test_plot_method_collection(together, num_vals): +@pytest.mark.parametrize( + ("prefix", "postfix"), [(None, None), ("prefix", None), (None, "postfix"), ("prefix", "postfix")] +) +def test_plot_method_collection(together, num_vals, prefix, postfix): """Test the plot method of metric collection.""" m_collection = MetricCollection( BinaryAccuracy(), BinaryPrecision(), BinaryRecall(), + prefix=prefix, + postfix=postfix, ) if num_vals == 1: m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))