diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66f6fe78856..da0b81906e7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))
 
+- Fixed bug related to `MetricCollection` used with custom metrics that have `prefix`/`postfix` attributes ([#2070](https://github.com/Lightning-AI/torchmetrics/pull/2070))
+
 ## [1.1.1] - 2023-08-29
 
 ### Added
diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py
index c7077f537f5..c1bbce0adb7 100644
--- a/src/torchmetrics/collections.py
+++ b/src/torchmetrics/collections.py
@@ -339,7 +339,8 @@ def _compute_and_reduce(
         _, duplicates = _flatten_dict(result)
 
         flattened_results = {}
-        for k, res in result.items():
+        for k, m in self.items(keep_base=True, copy_state=False):
+            res = result[k]
             if isinstance(res, dict):
                 for key, v in res.items():
                     # if duplicates of keys we need to add unique prefix to each key
@@ -347,9 +348,9 @@
                         stripped_k = k.replace(getattr(m, "prefix", ""), "")
                         stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
                         key = f"{stripped_k}_{key}"
-                    if hasattr(m, "prefix") and m.prefix is not None:
+                    if getattr(m, "_from_collection", None) and m.prefix is not None:
                         key = f"{m.prefix}{key}"
-                    if hasattr(m, "postfix") and m.postfix is not None:
+                    if getattr(m, "_from_collection", None) and m.postfix is not None:
                         key = f"{key}{m.postfix}"
                     flattened_results[key] = v
             else:
@@ -425,6 +426,7 @@ def add_metrics(
                     for k, v in metric.items(keep_base=False):
                         v.postfix = metric.postfix
                         v.prefix = metric.prefix
+                        v._from_collection = True
                         self[f"{name}_{k}"] = v
         elif isinstance(metrics, Sequence):
             for metric in metrics:
@@ -442,6 +444,7 @@ def add_metrics(
                     for k, v in metric.items(keep_base=False):
                         v.postfix = metric.postfix
                         v.prefix = metric.prefix
+                        v._from_collection = True
                         self[k] = v
         else:
             raise ValueError(
diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py
index 834f764ff81..ebda89de83b 100644
--- a/tests/unittests/bases/test_collections.py
+++ b/tests/unittests/bases/test_collections.py
@@ -644,3 +644,36 @@ def test_double_nested_collections(base_metrics, expected):
 
     for key in val:
         assert key in expected
+
+
+def test_with_custom_prefix_postfix():
+    """Test that metric collection does not clash with custom prefix and postfix in user's metrics.
+
+    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2065
+
+    """
+
+    class CustomAccuracy(MulticlassAccuracy):
+        prefix = "my_prefix"
+        postfix = "my_postfix"
+
+        def compute(self):
+            value = super().compute()
+            return {f"{self.prefix}/accuracy/{self.postfix}": value}
+
+    class CustomPrecision(MulticlassAccuracy):
+        prefix = "my_prefix"
+        postfix = "my_postfix"
+
+        def compute(self):
+            value = super().compute()
+            return {f"{self.prefix}/precision/{self.postfix}": value}
+
+    metrics = MetricCollection([CustomAccuracy(num_classes=2), CustomPrecision(num_classes=2)])
+
+    # Update metrics with current batch
+    res = metrics(torch.tensor([1, 0, 0, 1]), torch.tensor([1, 0, 0, 0]))
+
+    # Check the custom-keyed results survive collection flattening
+    assert "my_prefix/accuracy/my_postfix" in res
+    assert "my_prefix/precision/my_postfix" in res