Skip to content

Commit

Permalink
fix: MetricCollection did not copy the inner state of the metric in…
Browse files Browse the repository at this point in the history
… `ClasswiseWrapper` when computing group metrics (#2390)

* fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics

Issue Link: #2389

* fix: set _persistent and _reductions be same as internal metric

* test: check metric state_dict wrapped in `ClasswiseWrapper`

---------

Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 1951a06)
  • Loading branch information
daniel-code authored and Borda committed Mar 16, 2024
1 parent 4e85685 commit f0b0175
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
23 changes: 23 additions & 0 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
Expand All @@ -20,6 +21,9 @@
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric

if typing.TYPE_CHECKING:
from torch.nn import Module

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ClasswiseWrapper.plot"]

Expand Down Expand Up @@ -209,3 +213,22 @@ def plot(
"""
return self._plot(val, ax)

def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
"""Get attribute from classwise wrapper."""
# return state from self.metric
if name in ["tp", "fp", "fn", "tn"]:
return getattr(self.metric, name)

return super().__getattr__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""Set attribute to classwise wrapper."""
super().__setattr__(name, value)
if name == "metric":
self._defaults = self.metric._defaults
self._persistent = self.metric._persistent
self._reductions = self.metric._reductions
if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]:
# update ``_update_count`` and ``_computed`` of internal metric to prevent warning.
setattr(self.metric, name, value)
40 changes: 39 additions & 1 deletion tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest
import torch
from torchmetrics import Metric, MetricCollection
from torchmetrics import ClasswiseWrapper, Metric, MetricCollection
from torchmetrics.classification import (
BinaryAccuracy,
MulticlassAccuracy,
Expand Down Expand Up @@ -540,6 +540,44 @@ def test_compute_group_define_by_user():
assert m.compute()


def test_classwise_wrapper_compute_group():
"""Check that user can provide compute groups."""
classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy")
classwise_recall = ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall")
classwise_precision = ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision")

m = MetricCollection(
{
"accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"),
"recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"),
"precision": ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision"),
},
compute_groups=[["accuracy", "recall", "precision"]],
)

# Check that we are not going to check the groups in the first update
assert m._groups_checked
assert m.compute_groups == {0: ["accuracy", "recall", "precision"]}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))

expected = {
**classwise_accuracy(preds, target),
**classwise_recall(preds, target),
**classwise_precision(preds, target),
}

m.update(preds, target)
res = m.compute()

for key in expected:
assert torch.allclose(res[key], expected[key])

# check metric state_dict
m.state_dict()


def test_compute_on_different_dtype():
"""Check that extraction of compute groups are robust towards difference in dtype."""
m = MetricCollection([
Expand Down

0 comments on commit f0b0175

Please sign in to comment.