From 1951a06fc914a26152e635f62aa8b32399d4c700 Mon Sep 17 00:00:00 2001 From: Su YR Date: Tue, 5 Mar 2024 17:17:21 +0800 Subject: [PATCH] fix: `MetricCollection` did not copy the inner state of the metric in `ClasswiseWrapper` when computing group metrics (#2390) * fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics Issue Link: https://github.com/Lightning-AI/torchmetrics/issues/2389 * fix: set _persistent and _reductions be same as internal metric * test: check metric state_dict wrapped in `ClasswiseWrapper` --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/wrappers/classwise.py | 23 +++++++++++++ tests/unittests/bases/test_collections.py | 40 ++++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 3c8d6621bc2..698d0f51848 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import typing from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor @@ -20,6 +21,9 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE from torchmetrics.wrappers.abstract import WrapperMetric +if typing.TYPE_CHECKING: + from torch.nn import Module + if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["ClasswiseWrapper.plot"] @@ -209,3 +213,22 @@ def plot( """ return self._plot(val, ax) + + def __getattr__(self, name: str) -> Union[Tensor, "Module"]: + """Get attribute from classwise wrapper.""" + # return state from self.metric + if name in ["tp", "fp", "fn", "tn"]: + return getattr(self.metric, name) + + return super().__getattr__(name) + + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute to classwise wrapper.""" + super().__setattr__(name, value) + if name == "metric": + self._defaults = self.metric._defaults + self._persistent = self.metric._persistent + self._reductions = self.metric._reductions + if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]: + # update ``_update_count`` and ``_computed`` of internal metric to prevent warning. + setattr(self.metric, name, value) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 9e4ac4a5897..16c95fc879a 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -17,7 +17,7 @@ import pytest import torch -from torchmetrics import Metric, MetricCollection +from torchmetrics import ClasswiseWrapper, Metric, MetricCollection from torchmetrics.classification import ( BinaryAccuracy, MulticlassAccuracy, @@ -540,6 +540,44 @@ def test_compute_group_define_by_user(): assert m.compute() +def test_classwise_wrapper_compute_group(): + """Check that user can provide compute groups.""" + classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy") + classwise_recall = ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall") + classwise_precision = ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision") + + m = MetricCollection( + { + "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"), + "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"), + "precision": ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision"), + }, + compute_groups=[["accuracy", "recall", "precision"]], + ) + + # Check that we are not going to check the groups in the first update + assert m._groups_checked + assert m.compute_groups == {0: ["accuracy", "recall", "precision"]} + + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + + expected = { + **classwise_accuracy(preds, target), + **classwise_recall(preds, target), + **classwise_precision(preds, target), + } + + m.update(preds, target) + res = m.compute() + + for key in expected: + assert torch.allclose(res[key], expected[key]) + + # check metric state_dict + m.state_dict() + + def test_compute_on_different_dtype(): """Check that extraction of compute groups are robust towards difference in dtype.""" m = MetricCollection([