Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MetricCollection when input are metrics that return dicts with same keywords #2027

Merged
merged 11 commits into from
Aug 28, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)


- Fixed bug in `MetricCollection` when used with multiple metrics that return dicts with same keys ([#2027](https://github.com/Lightning-AI/torchmetrics/pull/2027)


## [1.1.0] - 2023-08-22

### Added
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ def _compute_and_reduce(

if isinstance(res, dict):
for key, v in res.items():
stripped_k = k.replace(m.prefix, "").replace(m.postfix, "")
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
key = f"{stripped_k}_{key}"
if hasattr(m, "prefix") and m.prefix is not None:
key = f"{m.prefix}{key}"
if hasattr(m, "postfix") and m.postfix is not None:
Expand Down
30 changes: 26 additions & 4 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,33 @@ def test_nested_collections(input_collections):
assert "valmetrics/micro_MulticlassPrecision" in val


def test_double_nested_collections():
@pytest.mark.parametrize(
("base_metrics", "expected"),
[
(
DummyMetricMultiOutputDict(),
(
"prefix2_prefix1_DummyMetricMultiOutputDict_output1_postfix1_postfix2",
"prefix2_prefix1_DummyMetricMultiOutputDict_output2_postfix1_postfix2",
),
),
(
{"metric1": DummyMetricMultiOutputDict(), "metric2": DummyMetricMultiOutputDict()},
(
"prefix2_prefix1_metric1_output1_postfix1_postfix2",
"prefix2_prefix1_metric1_output2_postfix1_postfix2",
"prefix2_prefix1_metric2_output1_postfix1_postfix2",
"prefix2_prefix1_metric2_output2_postfix1_postfix2",
),
),
],
)
def test_double_nested_collections(base_metrics, expected):
"""Test that double nested collections gets flattened to a single collection."""
collection1 = MetricCollection([DummyMetricMultiOutputDict()], prefix="prefix1_", postfix="_postfix1")
collection1 = MetricCollection(base_metrics, prefix="prefix1_", postfix="_postfix1")
collection2 = MetricCollection([collection1], prefix="prefix2_", postfix="_postfix2")
x = torch.randn(10).sum()
val = collection2(x)
assert "prefix2_prefix1_output1_postfix1_postfix2" in val
assert "prefix2_prefix1_output2_postfix1_postfix2" in val

for key in val:
assert key in expected