Incorrect caching (m._compute) of metrics inside a MetricCollection if compute_groups are used and .compute is called twice #2570

Closed
relativityhd opened this issue May 28, 2024 · 1 comment · Fixed by #2571
Labels
bug / fix (Something isn't working) · help wanted (Extra attention is needed)

Comments

@relativityhd
Contributor

relativityhd commented May 28, 2024

🐛 Bug

Hello Team! I have found a bug which I have tracked down to the caching attribute _computed of metrics inside a MetricCollection. If the .compute() method is called twice, the second call returns different (and incorrect) results when compute groups are used. The reason is that the first .compute() call stores the result of the first metric in the _computed cache of all other metrics in its compute group, and that cached value is then returned by the second call.
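
To illustrate the mechanism in isolation, here is a minimal, self-contained sketch in plain Python (not torchmetrics code): each metric keeps its last result in a private _computed cache, and the compute-group syncing effectively copies the first metric's cache onto the other members of the group, so their next compute call returns the stale value.

class ToyMetric:
    def __init__(self, name):
        self.name = name
        self._computed = None  # result cache, analogous to Metric._computed

    def compute(self):
        # Return the cached value if present, otherwise "compute" and cache it
        if self._computed is None:
            self._computed = f"result of {self.name}"
        return self._computed

auroc = ToyMetric("auroc")
roc = ToyMetric("roc")
print(auroc.compute())           # result of auroc
roc._computed = auroc._computed  # what the compute-group sync effectively does
print(roc.compute())             # result of auroc -- stale value from the wrong metric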

To Reproduce

Steps to reproduce the behavior...

Code sample
import torch
from torchmetrics import AUROC, ROC, Recall, F1Score, MetricCollection

metrics = MetricCollection({
    "auroc": AUROC(task='binary'),
    "roc": ROC(task='binary'),
    "recall": Recall(task='binary'),
    "f1": F1Score(task='binary')
})

y_true = torch.tensor([1, 0, 0, 1])
y_pred = torch.tensor([0.6, 0.2, 0.4, 0.2])

# accumulate the same batch a few times
for _ in range(10):
    metrics.update(y_pred, y_true)

print("First compute call:")
print(metrics.compute())
print("Second (and all consecutive) compute call:")
print(metrics.compute())
print("Compute Groups:")
print(metrics.compute_groups)

This results in:

First compute call:
{
    'auroc': tensor(0.6250),
    'f1': tensor(0.6667),
    'recall': tensor(0.5000),
    'roc': (
        tensor([0.0000, 0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 0.5000, 1.0000]),
        tensor([1.0000, 0.6000, 0.4000, 0.2000])
    )
}
Second (and all consecutive) compute call:
{'auroc': tensor(0.6250), 'f1': tensor(0.6667), 'recall': tensor(0.6667), 'roc': tensor(0.6250)}
Compute Groups:
{0: ['auroc', 'roc'], 1: ['f1', 'recall']}
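
For completeness, the stale cache can be observed directly on the individual metrics after running the snippet above (a hedged sketch: _computed is a private attribute, and dict-style access works because MetricCollection subclasses ModuleDict):

print(metrics["f1"]._computed)      # tensor(0.6667)
print(metrics["recall"]._computed)  # tensor(0.6667) -- f1's value, not the correct 0.5000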

Expected behavior

I expect the output of the above code to be:
First compute call:
{
    'auroc': tensor(0.6250),
    'f1': tensor(0.6667),
    'recall': tensor(0.5000),
    'roc': (
        tensor([0.0000, 0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 0.5000, 1.0000]),
        tensor([1.0000, 0.6000, 0.4000, 0.2000])
    )
}
Second (and all consecutive) compute call:
{
    'auroc': tensor(0.6250),
    'f1': tensor(0.6667),
    'recall': tensor(0.5000),
    'roc': (
        tensor([0.0000, 0.0000, 0.5000, 1.0000]),
        tensor([0.0000, 0.5000, 0.5000, 1.0000]),
        tensor([1.0000, 0.6000, 0.4000, 0.2000])
    )
}
Compute Groups:
{0: ['auroc', 'roc'], 1: ['f1', 'recall']}

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 1.4.0.post0 installed via rye add
  • Python & PyTorch Version (e.g., 1.0): Python 3.10.14 & PyTorch 2.2.0
  • Any other relevant information such as OS (e.g., Linux): Ubuntu 22 on WSL

Additional context

I have already tracked down the responsible line:

mi._computed = deepcopy(m0._computed) if copy else m0._computed

It was introduced by PR #2211. Since that PR also fixed a real problem with this line, simply reverting the change would not help. I am currently trying to find another solution for this.

But Tobi, why do you want to have two compute calls in the first place?

The reason I want to make two separate .compute() calls is that in my current code structure I log my metrics in two different functions, both of which have access to the same metrics object. The use case is actually pretty close to the example above. Of course I could work around this by storing the result myself and handing it to the respective functions (as sketched below), but that would make my code less readable, so a fix for this issue would be a nice-to-have for me. More importantly, other people could run into the same problem without noticing, since for some metrics (e.g. scalar ones) the code runs as usual, just with wrong outputs.
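
For reference, the workaround mentioned above would look roughly like this (a sketch only; log_to_console and log_to_tensorboard are hypothetical placeholders for my two logging functions):

# Compute once, keep the result dict, and hand it to both logging functions
# instead of calling metrics.compute() a second time.
results = metrics.compute()

def log_to_console(results):
    for name, value in results.items():
        print(name, value)

def log_to_tensorboard(results):
    ...  # e.g. write scalar entries for the scalar metrics

log_to_console(results)
log_to_tensorboard(results)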

relativityhd added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels on May 28, 2024
@relativityhd
Contributor Author

relativityhd commented May 28, 2024

  1. I found out that the tests for Fix MetricCollection with repeated compute calls #2211 are not implemented properly: deleting the line added by that PR still lets all tests pass. Validating manually, however, shows that the bug fixed by #2211 does reappear when that line is removed. A minimal regression test that would also catch the current issue is sketched below.
  2. I also found out that some contents of compute can be references instead of deepcopies, which results in even weirder side effects.
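
Something along these lines would catch both a regression of #2211 and this issue (a sketch only, assuming a pytest-style test; the actual test in the upcoming PR may differ):

import torch
from torchmetrics import F1Score, MetricCollection, Recall

def test_compute_called_twice_returns_same_result():
    # f1 and recall share a compute group, which is where the stale cache shows up
    metrics = MetricCollection({"f1": F1Score(task="binary"), "recall": Recall(task="binary")})
    metrics.update(torch.tensor([0.6, 0.2, 0.4, 0.2]), torch.tensor([1, 0, 0, 1]))
    first = metrics.compute()
    second = metrics.compute()
    for name in first:
        assert torch.allclose(first[name], second[name])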

I have fixed all three things (the original bug and the two new findings). I will open a PR for this in a sec.
