
MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics #2389

Closed
daniel-code opened this issue Feb 16, 2024 · 1 comment · Fixed by #2390 or #2424
Labels
bug / fix Something isn't working help wanted Extra attention is needed v1.3.x

Comments


daniel-code commented Feb 16, 2024

🐛 Bug

MetricCollection does not copy the inner state of a metric wrapped in ClasswiseWrapper when computing group metrics.

To Reproduce

Run the code sample below and compare the result of forward with the result of update followed by compute.

Code sample
import torch
from lightning import seed_everything
from torchmetrics import MetricCollection, ClasswiseWrapper
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

seed_everything(42)

random_pred = torch.rand((10, 3))
pred = torch.softmax(random_pred, dim=-1)
pred_class = torch.argmax(pred, dim=-1)
target = torch.randint(0, 3, size=(10,))

multiclass_acc = MulticlassAccuracy(
    num_classes=3,
    average=None,
)
print("multiclass_acc:", multiclass_acc(pred, target))

multiclass_f1 = MulticlassF1Score(
    num_classes=3,
    average=None,
)
print("multiclass_f1:", multiclass_f1(pred, target))

mc = MetricCollection(
    {
        "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
        "f1": ClasswiseWrapper(MulticlassF1Score(num_classes=3, average=None)),
    },
    compute_groups=[
        ["accuracy", "f1"],
    ],
)

print("MetricCollection.forward:", mc(pred, target))
mc.reset()
mc.update(pred, target)
print("MetricCollection.update&compute:", mc.compute())
Output
Seed set to 42
site-packages\torchmetrics\utilities\prints.py:43: UserWarning: The ``compute`` method of metric MulticlassF1Score was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
multiclass_acc: tensor([0.3333, 0.5000, 0.3333])
multiclass_f1: tensor([0.4000, 0.4444, 0.3333])
MetricCollection.forward: {'multiclassaccuracy_0': tensor(0.3333), 'multiclassaccuracy_1': tensor(0.5000), 'multiclassaccuracy_2': tensor(0.3333), 'multiclassf1score_0': tensor(0.4000), 'multiclassf1score_1': tensor(0.4444), 'multiclassf1score_2': tensor(0.3333)}
MetricCollection.update&compute: {'multiclassaccuracy_0': tensor(0.3333), 'multiclassaccuracy_1': tensor(0.5000), 'multiclassaccuracy_2': tensor(0.3333), 'multiclassf1score_0': tensor(0.), 'multiclassf1score_1': tensor(0.), 'multiclassf1score_2': tensor(0.)}

Expected behavior

The metrics multiclassf1score_0, multiclassf1score_1, and multiclassf1score_2 returned by MetricCollection.compute should match the values returned by MetricCollection.forward (and by each metric computed individually), instead of coming back as zeros.
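The failure mode can be illustrated with a small toy model (the class names below are illustrative, not torchmetrics internals): the compute-group sync sets state attributes on the wrapper object itself, so the wrapped metric that actually performs the computation keeps its own stale, empty state.

```python
# Toy model of the bug: state is synced onto the wrapper, not the inner metric.

class ToyMetric:
    def __init__(self):
        self.tp = 0  # metric state, updated during accumulation

    def compute(self):
        return self.tp


class ToyWrapper:
    def __init__(self, metric):
        self.metric = metric

    def compute(self):
        # delegates to the inner metric, which never saw the synced state
        return self.metric.compute()


m0 = ToyMetric()
m0.tp = 5  # the group "leader" accumulated real state

wrapped = ToyWrapper(ToyMetric())
# the group sync copies state onto the wrapper object, not the inner metric
setattr(wrapped, "tp", m0.tp)

print(wrapped.tp)         # the wrapper attribute was set to 5
print(wrapped.compute())  # but the inner metric's state is still 0
```

This mirrors why the wrapped F1 score reports zeros: its inner metric's `tp`/`fp`/`fn`/`tn` buffers were never updated, only shadowed on the wrapper.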

Solution 1

https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/collections.py#L305

class MetricCollection:
    def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
        """Create reference between metrics in the same compute group.

        Args:
            copy: If ``True``, the metric state between members will be copied
                instead of just passed by reference

        """
        if not self._state_is_copy:
            for cg in self._groups.values():
                m0 = getattr(self, cg[0])
                for i in range(1, len(cg)):
                    mi = getattr(self, cg[i])
                    for state in m0._defaults:
                        m0_state = getattr(m0, state)
                        # Determine if we just should set a reference or a full copy
                        if isinstance(mi, ClasswiseWrapper):  # << Added
                            setattr(mi.metric, state, deepcopy(m0_state) if copy else m0_state) # << Added
                        setattr(mi, state, deepcopy(m0_state) if copy else m0_state)

                    if isinstance(mi, ClasswiseWrapper): # << Added
                        mi.metric._update_count = deepcopy(m0._update_count) if copy else m0._update_count # << Added
                        mi.metric._computed = deepcopy(m0._computed) if copy else m0._computed # << Added
                    mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
                    mi._computed = deepcopy(m0._computed) if copy else m0._computed
        self._state_is_copy = copy
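The `copy` flag in the patch above controls the usual reference-versus-deepcopy distinction; a minimal sketch (plain Python lists standing in for metric state tensors):

```python
from copy import deepcopy

leader_state = [1, 2, 3]

shared = leader_state              # copy=False: same object, passed by reference
isolated = deepcopy(leader_state)  # copy=True: independent full copy

leader_state.append(4)
print(shared)    # follows the leader, since it is the same object
print(isolated)  # unaffected by later updates to the leader
```

With references, every group member sees the leader's accumulated state for free; the copy is only made when the states need to diverge (e.g. before `compute`).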
Solution 2

https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/wrappers/classwise.py#L27

class ClasswiseWrapper:
    def __getattr__(self, name: str):
        # return state from self.metric
        if name in ["tp", "fp", "fn", "tn"]:   # <<Added
            return getattr(self.metric, name)  # <<Added

        return super().__getattr__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if hasattr(self, "metric") and name in ["tp", "fp", "fn", "tn"]:  # <<Added
            setattr(self.metric, name, value)                             # <<Added
        else:                                                             # <<Added
            super().__setattr__(name, value)
            if name == "metric":                                          # <<Added
                self._defaults = self.metric._defaults                    # <<Added
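The pattern Solution 2 relies on is generic attribute delegation: reads and writes of the state names are forwarded to the inner object. A self-contained sketch (class names are illustrative; `tp`/`fp`/`fn`/`tn` are the confusion-matrix state names of the wrapped classification metric):

```python
STATE_NAMES = ("tp", "fp", "fn", "tn")


class Inner:
    def __init__(self):
        self.tp = self.fp = self.fn = self.tn = 0


class Delegating:
    def __init__(self, inner):
        # bypass our own __setattr__ so "inner" lands on the wrapper itself
        object.__setattr__(self, "inner", inner)

    def __getattr__(self, name):
        # only called for attributes not found on the wrapper itself
        if name in STATE_NAMES:
            return getattr(self.inner, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name in STATE_NAMES:
            setattr(self.inner, name, value)  # forward state writes
        else:
            object.__setattr__(self, name, value)


w = Delegating(Inner())
w.tp = 7              # forwarded to the inner object
print(w.inner.tp)     # the inner object received the write
print(w.tp)           # and reads come back through the wrapper
```

With this in place, the group-sync code in `_compute_groups_create_state_ref` can stay unchanged: setting `tp` on the wrapper transparently updates the wrapped metric.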

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 1.3.1
  • Python & PyTorch Version:
    • Python: 3.8
    • PyTorch: 2.1.1
  • Any other relevant information such as OS (e.g., Linux): Windows 11

@daniel-code daniel-code added bug / fix Something isn't working help wanted Extra attention is needed labels Feb 16, 2024
@Borda Borda added the v1.3.x label Feb 16, 2024

daniel-code commented Feb 16, 2024

Updated Solution 2: it now overrides the __getattr__ and __setattr__ of ClasswiseWrapper directly.

daniel-code added a commit to daniel-code/torchmetrics that referenced this issue Feb 16, 2024
SkafteNicki pushed a commit that referenced this issue Mar 5, 2024
… `ClasswiseWrapper` when computing group metrics (#2390)

* fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics

Issue Link: #2389

* fix: set _persistent and _reductions be same as internal metric

* test: check metric state_dict wrapped in `ClasswiseWrapper`

---------

Co-authored-by: Jirka Borovec <[email protected]>
Borda added a commit that referenced this issue Mar 5, 2024
…eneral (#2424)

* fix: MetricCollection did not copy inner state of metric in ClasswiseWrapper when computing groups metrics

Issue Link: #2389

* fix: set _persistent and _reductions be same as internal metric

* test: check metric state_dict wrapped in `ClasswiseWrapper`

* refactor: make __getattr__ and __setattr__ of ClasswiseWrapper more general

* chlog

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
Borda pushed a commit that referenced this issue Mar 16, 2024
… `ClasswiseWrapper` when computing group metrics (#2390)
(cherry picked from commit 1951a06)
Borda pushed a commit that referenced this issue Mar 18, 2024
… `ClasswiseWrapper` when computing group metrics (#2390)
(cherry picked from commit 1951a06)