Skip to content

Commit

Permalink
Bugfix for custom prefix/postfix and metric collection (#2070)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
(cherry picked from commit 6538d1a)
  • Loading branch information
SkafteNicki authored and Borda committed Sep 11, 2023
1 parent 5c2db0b commit 813f3c0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed performance issues in `RecallAtFixedPrecision` for large batch sizes ([#2042](https://github.com/Lightning-AI/torchmetrics/pull/2042))


- Fixed bug related to `MetricCollection` used with custom metrics have `prefix`/`postfix` attributes ([#2070](https://github.com/Lightning-AI/torchmetrics/pull/2070))

## [1.1.1] - 2023-08-29

### Added
Expand Down
9 changes: 6 additions & 3 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,18 @@ def _compute_and_reduce(
_, duplicates = _flatten_dict(result)

flattened_results = {}
for k, res in result.items():
for k, m in self.items(keep_base=True, copy_state=False):
res = result[k]
if isinstance(res, dict):
for key, v in res.items():
# if duplicates of keys we need to add unique prefix to each key
if duplicates:
stripped_k = k.replace(getattr(m, "prefix", ""), "")
stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
key = f"{stripped_k}_{key}"
if hasattr(m, "prefix") and m.prefix is not None:
if getattr(m, "_from_collection", None) and m.prefix is not None:
key = f"{m.prefix}{key}"
if hasattr(m, "postfix") and m.postfix is not None:
if getattr(m, "_from_collection", None) and m.postfix is not None:
key = f"{key}{m.postfix}"
flattened_results[key] = v
else:
Expand Down Expand Up @@ -425,6 +426,7 @@ def add_metrics(
for k, v in metric.items(keep_base=False):
v.postfix = metric.postfix
v.prefix = metric.prefix
v._from_collection = True
self[f"{name}_{k}"] = v
elif isinstance(metrics, Sequence):
for metric in metrics:
Expand All @@ -442,6 +444,7 @@ def add_metrics(
for k, v in metric.items(keep_base=False):
v.postfix = metric.postfix
v.prefix = metric.prefix
v._from_collection = True
self[k] = v
else:
raise ValueError(
Expand Down
33 changes: 33 additions & 0 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,36 @@ def test_double_nested_collections(base_metrics, expected):

for key in val:
assert key in expected


def test_with_custom_prefix_postfix():
"""Test that metric colection does not clash with custom prefix and postfix in users metrics.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2065
"""

class CustomAccuracy(MulticlassAccuracy):
prefix = "my_prefix"
postfix = "my_postfix"

def compute(self):
value = super().compute()
return {f"{self.prefix}/accuracy/{self.postfix}": value}

class CustomPrecision(MulticlassAccuracy):
prefix = "my_prefix"
postfix = "my_postfix"

def compute(self):
value = super().compute()
return {f"{self.prefix}/precision/{self.postfix}": value}

metrics = MetricCollection([CustomAccuracy(num_classes=2), CustomPrecision(num_classes=2)])

# Update metrics with current batch
res = metrics(torch.tensor([1, 0, 0, 1]), torch.tensor([1, 0, 0, 0]))

# Print the calculated metrics
assert "my_prefix/accuracy/my_postfix" in res
assert "my_prefix/precision/my_postfix" in res

0 comments on commit 813f3c0

Please sign in to comment.