Skip to content

Commit

Permalink
Fix plotting of metric collection when prefix/postfix is set (#2429)
Browse files Browse the repository at this point in the history
* implementation
* add tests
* changelog

(cherry picked from commit 0a82679)
  • Loading branch information
SkafteNicki authored and Borda committed Mar 16, 2024
1 parent 03e08bd commit c347715
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))


- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429))


- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))


Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,12 +647,11 @@ def plot(
f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
)

val = val or self.compute()
if together:
return plot_single_or_multi_val(val, ax=ax)
fig_axs = []
for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
if isinstance(val, dict):
f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
elif isinstance(val, Sequence):
Expand Down
7 changes: 6 additions & 1 deletion tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,12 +855,17 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label

@pytest.mark.parametrize("together", [True, False])
@pytest.mark.parametrize("num_vals", [1, 2])
def test_plot_method_collection(together, num_vals):
@pytest.mark.parametrize(
("prefix", "postfix"), [(None, None), ("prefix", None), (None, "postfix"), ("prefix", "postfix")]
)
def test_plot_method_collection(together, num_vals, prefix, postfix):
"""Test the plot method of metric collection."""
m_collection = MetricCollection(
BinaryAccuracy(),
BinaryPrecision(),
BinaryRecall(),
prefix=prefix,
postfix=postfix,
)
if num_vals == 1:
m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))
Expand Down

0 comments on commit c347715

Please sign in to comment.