Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MetricCollection when input are metrics that return dicts with same keywords #2027

Merged
merged 11 commits into from
Aug 28, 2023
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)
- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019))


- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017))


- Fixed bug in `MetricCollection` when used with multiple metrics that return dicts with same keys ([#2027](https://github.com/Lightning-AI/torchmetrics/pull/2027))


- Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924))


- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)
- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028))


## [1.1.0] - 2023-08-22
Expand Down
18 changes: 14 additions & 4 deletions src/torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import allclose
from torchmetrics.utilities.data import _flatten_dict, allclose
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

Expand Down Expand Up @@ -334,17 +334,27 @@ def _compute_and_reduce(
res = m(*args, **m._filter_kwargs(**kwargs))
else:
raise ValueError("method_name should be either 'compute' or 'forward', but got {method_name}")
result[k] = res

_, duplicates = _flatten_dict(result)

flattened_results = {}
for k, res in result.items():
if isinstance(res, dict):
for key, v in res.items():
# if duplicates of keys we need to add unique prefix to each key
if duplicates:
stripped_k = k.replace(getattr(m, "prefix", ""), "")
stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
key = f"{stripped_k}_{key}"
if hasattr(m, "prefix") and m.prefix is not None:
key = f"{m.prefix}{key}"
if hasattr(m, "postfix") and m.postfix is not None:
key = f"{key}{m.postfix}"
result[key] = v
flattened_results[key] = v
else:
result[k] = res
return {self._set_name(k): v for k, v in result.items()}
flattened_results[k] = res
return {self._set_name(k): v for k, v in flattened_results.items()}

def reset(self) -> None:
"""Call reset for each metric sequentially."""
Expand Down
13 changes: 9 additions & 4 deletions src/torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from lightning_utilities import apply_to_collection
Expand Down Expand Up @@ -60,16 +60,21 @@ def _flatten(x: Sequence) -> list:
return [item for sublist in x for item in sublist]


def _flatten_dict(x: Dict) -> Dict:
"""Flatten dict of dicts into single dict."""
def _flatten_dict(x: Dict) -> Tuple[Dict, bool]:
"""Flatten dict of dicts into single dict and checking for duplicates in keys along the way."""
new_dict = {}
duplicates = False
for key, value in x.items():
if isinstance(value, dict):
for k, v in value.items():
if k in new_dict:
duplicates = True
new_dict[k] = v
else:
if key in new_dict:
duplicates = True
new_dict[key] = value
return new_dict
return new_dict, duplicates


def to_onehot(
Expand Down
30 changes: 26 additions & 4 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,33 @@ def test_nested_collections(input_collections):
assert "valmetrics/micro_MulticlassPrecision" in val


def test_double_nested_collections():
@pytest.mark.parametrize(
("base_metrics", "expected"),
[
(
DummyMetricMultiOutputDict(),
(
"prefix2_prefix1_output1_postfix1_postfix2",
"prefix2_prefix1_output2_postfix1_postfix2",
),
),
(
{"metric1": DummyMetricMultiOutputDict(), "metric2": DummyMetricMultiOutputDict()},
(
"prefix2_prefix1_metric1_output1_postfix1_postfix2",
"prefix2_prefix1_metric1_output2_postfix1_postfix2",
"prefix2_prefix1_metric2_output1_postfix1_postfix2",
"prefix2_prefix1_metric2_output2_postfix1_postfix2",
),
),
],
)
def test_double_nested_collections(base_metrics, expected):
"""Test that double nested collections gets flattened to a single collection."""
collection1 = MetricCollection([DummyMetricMultiOutputDict()], prefix="prefix1_", postfix="_postfix1")
collection1 = MetricCollection(base_metrics, prefix="prefix1_", postfix="_postfix1")
collection2 = MetricCollection([collection1], prefix="prefix2_", postfix="_postfix2")
x = torch.randn(10).sum()
val = collection2(x)
assert "prefix2_prefix1_output1_postfix1_postfix2" in val
assert "prefix2_prefix1_output2_postfix1_postfix2" in val

for key in val:
assert key in expected
5 changes: 3 additions & 2 deletions tests/unittests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def test_flatten_list():
def test_flatten_dict():
"""Check that _flatten_dict utility function works as expected."""
inp = {"a": {"b": 1, "c": 2}, "d": 3}
out = _flatten_dict(inp)
assert out == {"b": 1, "c": 2, "d": 3}
out_dict, out_dup = _flatten_dict(inp)
assert out_dict == {"b": 1, "c": 2, "d": 3}
assert out_dup is False


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
Expand Down