MultioutputWrapper does not reset cleanly #1436

Closed
phschoepf opened this issue Jan 9, 2023 · 1 comment · Fixed by #1460

Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

phschoepf commented Jan 9, 2023

🐛 Bug

Calling MultioutputWrapper.compute() after MultioutputWrapper.reset() returns stale results that the reset should have cleared.

To Reproduce

Code sample

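```python
import torch
import torchmetrics

# Track a two-class confusion matrix separately for each of 2 outputs.
base_metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=2)
cf = torchmetrics.MultioutputWrapper(base_metric, num_outputs=2)

# Calling the wrapper directly goes through forward().
# Inputs have shape (num_samples, num_outputs): one sample, two outputs.
cf(torch.tensor([[0, 0]]), torch.tensor([[0, 0]]))
print("First result: ", cf.compute())

cf.reset()

# New data after the reset: prediction 1, target 0 for both outputs.
cf(torch.tensor([[1, 1]]), torch.tensor([[0, 0]]))
print("Second result: ", cf.compute())
```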

Output:

```
First result:  [tensor([[1, 0], [0, 0]]), tensor([[1, 0], [0, 0]])]
Second result:  [tensor([[1, 0], [0, 0]]), tensor([[1, 0], [0, 0]])]
```

The old result is returned even after resetting and passing in new data. If the first compute() call is omitted, the second result is as expected.
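For comparison, the second batch alone (prediction 1, target 0 for both outputs) should put its single count in row 0, column 1 of each matrix. The plain-Python sketch below computes this without torch; it assumes, as with torchmetrics' ConfusionMatrix, that rows index the target class and columns the predicted class. The helper `confusion_matrix` is hypothetical and written only for this illustration.

```python
# Hypothetical helper: rows = target class, columns = predicted class.
def confusion_matrix(preds, targets, num_classes):
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        matrix[t][p] += 1
    return matrix


# Second batch from the reproduction: one sample, two outputs.
preds = [[1, 1]]
targets = [[0, 0]]
num_outputs = 2

# One confusion matrix per output column.
expected = [
    confusion_matrix(
        [row[i] for row in preds], [row[i] for row in targets], num_classes=2
    )
    for i in range(num_outputs)
]
print(expected)  # [[[0, 1], [0, 0]], [[0, 1], [0, 0]]]
```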

Importantly, this bug only occurs when data is passed in through forward() (i.e. by calling the wrapper directly, as above); passing the same data through update() works as expected.
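One possible explanation, which is an assumption and not something confirmed in this thread: the wrapper memoizes its compute() result, update() clears that cache, but the overridden forward() feeds the child metrics directly and the overridden reset() resets only the children. The stand-in classes below (`CachingMetric`, `SumMetric`, `BuggyWrapper`) are hypothetical simplifications, not torchmetrics internals, yet they reproduce all three reported symptoms: a stale result via forward(), a correct result via update(), and a correct result when the first compute() is skipped.

```python
# Hypothetical stand-ins for illustration only; these are NOT torchmetrics classes.


class CachingMetric:
    """Base class that memoizes compute() until update() or reset() clears it."""

    def __init__(self):
        self._computed = None

    def update(self, value):
        self._computed = None  # new data invalidates the cached result
        self._update(value)

    def compute(self):
        if self._computed is None:
            self._computed = self._compute()
        return self._computed

    def reset(self):
        self._computed = None


class SumMetric(CachingMetric):
    """Child metric: sums every value it receives."""

    def __init__(self):
        super().__init__()
        self.total = 0

    def _update(self, value):
        self.total += value

    def _compute(self):
        return self.total

    def reset(self):
        super().reset()
        self.total = 0


class BuggyWrapper(CachingMetric):
    """One SumMetric per output; forward() and reset() skip the base-class cache handling."""

    def __init__(self, num_outputs):
        super().__init__()
        self.children = [SumMetric() for _ in range(num_outputs)]

    def _update(self, values):
        for child, value in zip(self.children, values):
            child.update(value)

    def forward(self, values):
        # Feeds the children directly, so the wrapper's own cache is never cleared.
        for child, value in zip(self.children, values):
            child.update(value)

    def _compute(self):
        return [child.compute() for child in self.children]

    def reset(self):
        # Resets only the children; the wrapper's cached result survives.
        for child in self.children:
            child.reset()


# forward() path: the second compute() returns the stale first result.
fwd = BuggyWrapper(num_outputs=2)
fwd.forward([1, 1])
print(fwd.compute())  # [1, 1]
fwd.reset()
fwd.forward([5, 5])
print(fwd.compute())  # [1, 1] (stale)

# update() path: update() clears the wrapper's cache, so the result is fresh.
upd = BuggyWrapper(num_outputs=2)
upd.update([1, 1])
print(upd.compute())  # [1, 1]
upd.reset()
upd.update([5, 5])
print(upd.compute())  # [5, 5]
```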

Expected behavior

The result of the second computation should be independent of the first. Furthermore, forward and update should produce the same state as specified in the docs.

Environment

  • torchmetrics 0.10.3, installed from PyPI
  • Python 3.8.9

Attempts to fix

Adding super().reset() at the top of MultioutputWrapper.reset() (as is done in, e.g., the minmax wrapper) seems to fix the bug:
https://github.com/Lightning-AI/metrics/blob/7b505ff1a3b88181bef2b0cdfa21ec593dcda3ff/src/torchmetrics/wrappers/multioutput.py#L133
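A minimal sketch of why the suggested fix could work, assuming the wrapper memoizes its compute() result and that the base-class reset() is what clears that cache. `CachingMetric`, `SumMetric` and `FixedWrapper` are hypothetical, simplified stand-ins, not the actual torchmetrics classes; the only change that matters is the super().reset() call at the top of the wrapper's reset().

```python
# Hypothetical stand-ins for illustration only; these are NOT torchmetrics classes.


class CachingMetric:
    def __init__(self):
        self._computed = None  # memoized result of compute()

    def compute(self):
        if self._computed is None:
            self._computed = self._compute()
        return self._computed

    def reset(self):
        self._computed = None  # drop the memoized result


class SumMetric(CachingMetric):
    def __init__(self):
        super().__init__()
        self.total = 0

    def update(self, value):
        self._computed = None  # new data invalidates the cached result
        self.total += value

    def _compute(self):
        return self.total

    def reset(self):
        super().reset()
        self.total = 0


class FixedWrapper(CachingMetric):
    def __init__(self, num_outputs):
        super().__init__()
        self.children = [SumMetric() for _ in range(num_outputs)]

    def forward(self, values):
        # Still feeds the children directly, bypassing the wrapper's own cache.
        for child, value in zip(self.children, values):
            child.update(value)

    def _compute(self):
        return [child.compute() for child in self.children]

    def reset(self):
        super().reset()  # the fix: clear the wrapper's own cached result first
        for child in self.children:
            child.reset()


w = FixedWrapper(num_outputs=2)
w.forward([1, 1])
print(w.compute())  # [1, 1]
w.reset()
w.forward([5, 5])
print(w.compute())  # [5, 5], no stale value
```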

phschoepf added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels Jan 9, 2023

github-actions bot commented Jan 9, 2023

Hi! Thanks for your contribution, great first issue!
