Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix plotting of metric collection when prefix/postfix is set #2429

Merged
merged 22 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a9e8a59
implementation
SkafteNicki Mar 5, 2024
77c84ae
add tests
SkafteNicki Mar 5, 2024
17a510b
changelog
SkafteNicki Mar 5, 2024
196957f
Merge branch 'master' into bugfix/plot_metriccollection
Borda Mar 6, 2024
b92edb4
Merge branch 'master' into bugfix/plot_metriccollection
Borda Mar 6, 2024
3e221af
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 7, 2024
fb5a62f
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 7, 2024
1045f63
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 13, 2024
775fe85
Merge branch 'master' into bugfix/plot_metriccollection
Borda Mar 14, 2024
06d564c
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 14, 2024
07524f2
ci/gpu: do not fail if HF cache is not present
Borda Mar 14, 2024
4832720
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 14, 2024
ef0cc27
ci/gpu: do not update ref on PR
Borda Mar 14, 2024
9fb488d
build(deps): update fire requirement from <=0.5.0 to <=0.6.0 in /requ…
dependabot[bot] Mar 14, 2024
92804ef
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 14, 2024
0240110
build(deps): bump pytest-timeout from 2.2.0 to 2.3.1 in /requirements…
dependabot[bot] Mar 14, 2024
758e719
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 14, 2024
090e3ed
ci/mergify: rename label `ready`
Borda Mar 14, 2024
58270f6
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 14, 2024
398b6de
Merge branch 'master' into bugfix/plot_metriccollection
Borda Mar 14, 2024
c896125
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 15, 2024
bef1d3d
Merge branch 'master' into bugfix/plot_metriccollection
mergify[bot] Mar 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))


- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429))


- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))


Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,12 +647,11 @@ def plot(
f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
)

val = val or self.compute()
if together:
return plot_single_or_multi_val(val, ax=ax)
fig_axs = []
for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
if isinstance(val, dict):
f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
elif isinstance(val, Sequence):
Expand Down
7 changes: 6 additions & 1 deletion tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,12 +834,17 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label

@pytest.mark.parametrize("together", [True, False])
@pytest.mark.parametrize("num_vals", [1, 2])
def test_plot_method_collection(together, num_vals):
@pytest.mark.parametrize(
("prefix", "postfix"), [(None, None), ("prefix", None), (None, "postfix"), ("prefix", "postfix")]
)
def test_plot_method_collection(together, num_vals, prefix, postfix):
"""Test the plot method of metric collection."""
m_collection = MetricCollection(
BinaryAccuracy(),
BinaryPrecision(),
BinaryRecall(),
prefix=prefix,
postfix=postfix,
)
if num_vals == 1:
m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))
Expand Down
Loading