diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a2be2c4936..f02b5d16dbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed behaviour of `confusionmatrix` for multilabel data to better match `multilabel_confusion_matrix` from sklearn ([#134](https://github.com/PyTorchLightning/metrics/pull/134)) - Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111)) - Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default ([#163](https://github.com/PyTorchLightning/metrics/pull/163)) +- Metrics passed as dict to `MetricCollection` will now always be in deterministic order ([#173](https://github.com/PyTorchLightning/metrics/pull/173)) ### Deprecated diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 68206debd66..b9f8e6955b7 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -156,3 +156,12 @@ def test_metric_collection_prefix_arg(tmpdir): out = new_metric_collection(5) for name in names: assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method' + + +def test_metric_collection_same_order(): + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + col1 = MetricCollection({"a": m1, "b": m2}) + col2 = MetricCollection({"b": m2, "a": m1}) + for k1, k2 in zip(col1.keys(), col2.keys()): + assert k1 == k2 diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index 3f0e0933c69..abfceaa1280 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -35,6 +35,7 @@ class MetricCollection(nn.ModuleDict): * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict. Use this format if you want to chain together multiple of the same metric with different parameters. + Note that the keys in the output dict will be sorted alphabetically. prefix: a string to append in front of the keys of the output dict @@ -78,7 +79,9 @@ def __init__( super().__init__() if isinstance(metrics, dict): # Check all values are metrics - for name, metric in metrics.items(): + # Make sure that metrics are added in deterministic order + for name in sorted(metrics.keys()): + metric = metrics[name] if not isinstance(metric, Metric): raise ValueError( f"Value {metric} belonging to key {name}"