fix: MetricCollection did not copy the inner state of the metric in ClasswiseWrapper when computing group metrics #2390

Merged · 7 commits · Mar 5, 2024
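
Context for the fix: MetricCollection's compute groups let metrics that share the same underlying state (here the classification tp/fp/tn/fn counts) update a single copy of that state and link the others to it. Because ClasswiseWrapper keeps its state on the inner metric, the linked group state never reached the wrapped metric, so group computation diverged from computing each metric separately. Below is a minimal sketch of the affected setup, mirroring the test added in this PR rather than an exact reproduction from an issue report; inputs are random:

import torch
from torchmetrics import ClasswiseWrapper, MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall

# two wrapped metrics built on the same underlying statistics (tp/fp/tn/fn),
# so MetricCollection can place them in one compute group
collection = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"),
        "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"),
    },
    compute_groups=[["accuracy", "recall"]],
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))

collection.update(preds, target)
print(collection.compute())  # before this fix, the linked group state was not mirrored into each wrapper
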
23 changes: 23 additions & 0 deletions src/torchmetrics/wrappers/classwise.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
@@ -20,6 +21,9 @@
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric

if typing.TYPE_CHECKING:
    from torch.nn import Module

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["ClasswiseWrapper.plot"]

@@ -209,3 +213,22 @@ def plot(

"""
return self._plot(val, ax)

    def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
        """Get attribute from classwise wrapper."""
        # forward reads of these state names to the wrapped metric
        if name in ["tp", "fp", "fn", "tn"]:
            return getattr(self.metric, name)

        return super().__getattr__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute to classwise wrapper."""
        super().__setattr__(name, value)
        if name == "metric":
            self._defaults = self.metric._defaults
            self._persistent = self.metric._persistent
            self._reductions = self.metric._reductions
        if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn", "_update_count", "_computed"]:
            # mirror state writes (including ``_update_count`` and ``_computed``) onto the internal metric to prevent warnings
            setattr(self.metric, name, value)
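
The two methods above make the wrapper transparent with respect to the inner metric's state: reads and writes of the listed state names pass through to self.metric, and assigning a new metric re-exports its _defaults, _persistent, and _reductions so that state_dict() and compute groups can see them. A small sketch of the resulting behaviour follows; this is my reading of the code above, not a documented API, and the tp/fp/tn/fn names come from the multiclass classification metrics' internal state:

import torch
from torchmetrics import ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy

wrapper = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy")
wrapper.update(torch.randn(10, 3).softmax(dim=-1), torch.randint(3, (10,)))

# __getattr__ forwards reads of the state names to the inner metric
assert torch.equal(wrapper.tp, wrapper.metric.tp)

# __setattr__ mirrors writes onto the inner metric, which is what a compute
# group relies on when it copies state between its members
wrapper.tp = torch.zeros_like(wrapper.metric.tp)
assert torch.equal(wrapper.metric.tp, wrapper.tp)
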
40 changes: 39 additions & 1 deletion tests/unittests/bases/test_collections.py
@@ -17,7 +17,7 @@

import pytest
import torch
-from torchmetrics import Metric, MetricCollection
+from torchmetrics import ClasswiseWrapper, Metric, MetricCollection
from torchmetrics.classification import (
    BinaryAccuracy,
    MulticlassAccuracy,
@@ -540,6 +540,44 @@ def test_compute_group_define_by_user():
    assert m.compute()


def test_classwise_wrapper_compute_group():
    """Check that user-provided compute groups work with classwise wrappers."""
    classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy")
    classwise_recall = ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall")
    classwise_precision = ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision")

    m = MetricCollection(
        {
            "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"),
            "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"),
            "precision": ClasswiseWrapper(MulticlassPrecision(num_classes=3, average=None), prefix="precision"),
        },
        compute_groups=[["accuracy", "recall", "precision"]],
    )

    # user-provided groups are trusted up front, so no group check happens on the first update
    assert m._groups_checked
    assert m.compute_groups == {0: ["accuracy", "recall", "precision"]}

    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))

    expected = {
        **classwise_accuracy(preds, target),
        **classwise_recall(preds, target),
        **classwise_precision(preds, target),
    }

    m.update(preds, target)
    res = m.compute()

    for key in expected:
        assert torch.allclose(res[key], expected[key])

    # state_dict() should work now that the wrapper exposes ``_defaults``/``_persistent``/``_reductions``
    m.state_dict()
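
For contrast with the _groups_checked assertion above: with user-supplied groups the collection trusts the grouping up front, while auto-detection has to observe at least one update before it can validate anything. A hedged sketch, assuming the default compute_groups=True behaviour:

import torch
from torchmetrics import ClasswiseWrapper, MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall

# same metrics as in the test above, but with automatic group detection
m_auto = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy"),
        "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), prefix="recall"),
    },
    compute_groups=True,  # the default: groups are detected from observed state
)

# no data has been seen yet, so no grouping has been validated
assert not m_auto._groups_checked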


def test_compute_on_different_dtype():
    """Check that extraction of compute groups is robust towards differences in dtype."""
    m = MetricCollection([