From 8189c59f1c0d82fd2aa1a76018b6f0a4e6a1fc4a Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Sat, 26 Aug 2023 19:12:05 +0200
Subject: [PATCH 1/6] implementation

---
 src/torchmetrics/collections.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index d82646d7ce7..3069ce4993f 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -337,6 +337,8 @@ def _compute_and_reduce(
 
             if isinstance(res, dict):
                 for key, v in res.items():
+                    stripped_k = k.replace(m.prefix, "").replace(m.postfix, "")
+                    key = f"{stripped_k}_{key}"
                     if hasattr(m, "prefix") and m.prefix is not None:
                         key = f"{m.prefix}{key}"
                     if hasattr(m, "postfix") and m.postfix is not None:

From 936d470b38e59a8fe59dea29188776618baa782a Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Sat, 26 Aug 2023 19:13:16 +0200
Subject: [PATCH 2/6] tests

---
 tests/unittests/bases/test_collections.py | 30 ++++++++++++++++++++---
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py
index fafef8fb14d..f4ee88d4813 100644
--- a/tests/unittests/bases/test_collections.py
+++ b/tests/unittests/bases/test_collections.py
@@ -614,11 +614,33 @@ def test_nested_collections(input_collections):
     assert "valmetrics/micro_MulticlassPrecision" in val
 
 
-def test_double_nested_collections():
+@pytest.mark.parametrize(
+    ("base_metrics", "expected"),
+    [
+        (
+            DummyMetricMultiOutputDict(),
+            (
+                "prefix2_prefix1_DummyMetricMultiOutputDict_output1_postfix1_postfix2",
+                "prefix2_prefix1_DummyMetricMultiOutputDict_output2_postfix1_postfix2",
+            ),
+        ),
+        (
+            {"metric1": DummyMetricMultiOutputDict(), "metric2": DummyMetricMultiOutputDict()},
+            (
+                "prefix2_prefix1_metric1_output1_postfix1_postfix2",
+                "prefix2_prefix1_metric1_output2_postfix1_postfix2",
+                "prefix2_prefix1_metric2_output1_postfix1_postfix2",
+                "prefix2_prefix1_metric2_output2_postfix1_postfix2",
+            ),
+        ),
+    ],
+)
+def test_double_nested_collections(base_metrics, expected):
     """Test that double nested collections gets flattened to a single collection."""
-    collection1 = MetricCollection([DummyMetricMultiOutputDict()], prefix="prefix1_", postfix="_postfix1")
+    collection1 = MetricCollection(base_metrics, prefix="prefix1_", postfix="_postfix1")
     collection2 = MetricCollection([collection1], prefix="prefix2_", postfix="_postfix2")
     x = torch.randn(10).sum()
     val = collection2(x)
-    assert "prefix2_prefix1_output1_postfix1_postfix2" in val
-    assert "prefix2_prefix1_output2_postfix1_postfix2" in val
+
+    for key in val:
+        assert key in expected

From dad17b15e30494dc4b001fd5269a3c26306f6d96 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Sat, 26 Aug 2023 19:29:02 +0200
Subject: [PATCH 3/6] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b23ce58d355..c5b544a6cf5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)
 
 
+- Fixed bug in `MetricCollection` when used with multiple metrics that return dicts with same keys ([#2027](https://github.com/Lightning-AI/torchmetrics/pull/2027)
+
+
 ## [1.1.0] - 2023-08-22
 
 ### Added

From 569e6026e9449665df7b639174ed4d160ce464e5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Mon, 28 Aug 2023 09:22:41 +0200
Subject: [PATCH 4/6] Update src/torchmetrics/collections.py

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 src/torchmetrics/collections.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index 3069ce4993f..1512df3d85b 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -337,7 +337,8 @@ def _compute_and_reduce(
 
             if isinstance(res, dict):
                 for key, v in res.items():
-                    stripped_k = k.replace(m.prefix, "").replace(m.postfix, "")
+                    stripped_k = k.replace(getattr(m, "prefix", ""), "")
+                    stripped_k = k.replace(getattr(m, "postfix", ""), "")
                     key = f"{stripped_k}_{key}"
                     if hasattr(m, "prefix") and m.prefix is not None:
                         key = f"{m.prefix}{key}"

From 90cd248e6fd199213dd2026165268179a7a9f822 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Mon, 28 Aug 2023 12:57:09 +0200
Subject: [PATCH 5/6] better backward compatibility

---
 src/torchmetrics/collections.py           | 21 ++++++++++++++-------
 src/torchmetrics/utilities/data.py        | 13 +++++++++----
 tests/unittests/bases/test_collections.py |  4 ++--
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index 1512df3d85b..c7077f537f5 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -23,7 +23,7 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
-from torchmetrics.utilities.data import allclose
+from torchmetrics.utilities.data import _flatten_dict, allclose
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
 
@@ -334,20 +334,27 @@ def _compute_and_reduce(
                 res = m(*args, **m._filter_kwargs(**kwargs))
             else:
                 raise ValueError("method_name should be either 'compute' or 'forward', but got {method_name}")
+            result[k] = res
+
+        _, duplicates = _flatten_dict(result)
+
+        flattened_results = {}
+        for k, res in result.items():
             if isinstance(res, dict):
                 for key, v in res.items():
-                    stripped_k = k.replace(getattr(m, "prefix", ""), "")
-                    stripped_k = k.replace(getattr(m, "postfix", ""), "")
-                    key = f"{stripped_k}_{key}"
+                    # if duplicates of keys we need to add unique prefix to each key
+                    if duplicates:
+                        stripped_k = k.replace(getattr(m, "prefix", ""), "")
+                        stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
+                        key = f"{stripped_k}_{key}"
                     if hasattr(m, "prefix") and m.prefix is not None:
                         key = f"{m.prefix}{key}"
                     if hasattr(m, "postfix") and m.postfix is not None:
                         key = f"{key}{m.postfix}"
-                    result[key] = v
+                    flattened_results[key] = v
             else:
-                result[k] = res
-        return {self._set_name(k): v for k, v in result.items()}
+                flattened_results[k] = res
+        return {self._set_name(k): v for k, v in flattened_results.items()}
 
     def reset(self) -> None:
         """Call reset for each metric sequentially."""
diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py
index ebb81679a02..8e818a144f7 100644
--- a/src/torchmetrics/utilities/data.py
+++ b/src/torchmetrics/utilities/data.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from lightning_utilities import apply_to_collection
@@ -60,16 +60,21 @@ def _flatten(x: Sequence) -> list:
     return [item for sublist in x for item in sublist]
 
 
-def _flatten_dict(x: Dict) -> Dict:
-    """Flatten dict of dicts into single dict."""
+def _flatten_dict(x: Dict) -> Tuple[Dict, bool]:
+    """Flatten dict of dicts into single dict and checking for duplicates in keys along the way."""
     new_dict = {}
+    duplicates = False
     for key, value in x.items():
         if isinstance(value, dict):
             for k, v in value.items():
+                if k in new_dict:
+                    duplicates = True
                 new_dict[k] = v
         else:
+            if key in new_dict:
+                duplicates = True
             new_dict[key] = value
-    return new_dict
+    return new_dict, duplicates
 
 
 def to_onehot(
diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py
index f4ee88d4813..834f764ff81 100644
--- a/tests/unittests/bases/test_collections.py
+++ b/tests/unittests/bases/test_collections.py
@@ -620,8 +620,8 @@ def test_nested_collections(input_collections):
         (
             DummyMetricMultiOutputDict(),
             (
-                "prefix2_prefix1_DummyMetricMultiOutputDict_output1_postfix1_postfix2",
-                "prefix2_prefix1_DummyMetricMultiOutputDict_output2_postfix1_postfix2",
+                "prefix2_prefix1_output1_postfix1_postfix2",
+                "prefix2_prefix1_output2_postfix1_postfix2",
             ),
         ),
         (

From 46883d717b428cd77051db56ba210ddb3fb87883 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Mon, 28 Aug 2023 17:55:37 +0200
Subject: [PATCH 6/6] test

---
 tests/unittests/utilities/test_utilities.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py
index d0e38abadfb..ca05ce5f75b 100644
--- a/tests/unittests/utilities/test_utilities.py
+++ b/tests/unittests/utilities/test_utilities.py
@@ -113,8 +113,9 @@ def test_flatten_list():
 def test_flatten_dict():
     """Check that _flatten_dict utility function works as expected."""
     inp = {"a": {"b": 1, "c": 2}, "d": 3}
-    out = _flatten_dict(inp)
-    assert out == {"b": 1, "c": 2, "d": 3}
+    out_dict, out_dup = _flatten_dict(inp)
+    assert out_dict == {"b": 1, "c": 2, "d": 3}
+    assert out_dup is False
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
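
Below is a minimal usage sketch (not part of the patch series) of the behaviour these
patches target. The MultiOutputDictMetric class and the collection keys "metric1"/"metric2"
are hypothetical stand-ins for the DummyMetricMultiOutputDict test helper used above;
the expected output keys follow the parametrized test added in patch 2.

import torch
from torchmetrics import Metric, MetricCollection


class MultiOutputDictMetric(Metric):
    """Toy metric whose compute() returns a dict with two entries."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> dict:
        return {"output1": self.total, "output2": 2 * self.total}


# Two metrics that both return dicts with the same keys ("output1"/"output2").
# Before the fix their results collided in the flattened output; with the duplicate
# check added to _flatten_dict, each key should be disambiguated with its collection key.
collection = MetricCollection(
    {"metric1": MultiOutputDictMetric(), "metric2": MultiOutputDictMetric()},
    prefix="prefix_",
    postfix="_postfix",
)
out = collection(torch.randn(10))
print(sorted(out))
# Expected keys, e.g.:
#   prefix_metric1_output1_postfix, prefix_metric1_output2_postfix,
#   prefix_metric2_output1_postfix, prefix_metric2_output2_postfix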