Incorrect caching (`m._compute`) of metrics inside a `MetricCollection` if `compute_groups` are used and `.compute` is called twice #2570
🐛 Bug
Hello Team! I have found a bug and have already tracked it down to the caching attribute `_compute` of a metric in a `MetricCollection`. If the `.compute()` method is called twice and `compute_groups` are used, the second call returns different (and wrong) results. This happens because the first `.compute()` call stores the result of the first metric in the `_compute` cache of every other metric in that compute group, and that cached value is then returned by the second call.
To Reproduce
Steps to reproduce the behavior...
Code sample
This results in:
Expected behavior
I expect the output of the above code to be:
Environment
TorchMetrics version (and installation method, e.g. conda, pip, build from source): 1.4.0.post0, installed via `rye add`
Additional context
I have already tracked down the responsible line:
`torchmetrics/src/torchmetrics/collections.py`, line 307 in 8d9d2ae
It was introduced by PR #2211. Since that PR also fixed another problem, simply reverting the change would not help. I am currently trying to find another solution.
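To make the mechanism concrete, here is a toy model in plain Python. It is not torchmetrics code: `ToyMetric`, `collection_compute`, `make_group`, and the `propagate_cached_value` flag are hypothetical names. The model only encodes what the issue describes: members of a compute group share state with the first metric, and the first metric's cached result ends up in the other members' caches. It additionally assumes that this copy happens every time the group is re-linked before computing. The `False` branch sketches one possible alternative (propagate only the fact that the cache is stale, never the cached value itself); it is an illustration, not necessarily the fix that will land upstream.

```python
class ToyMetric:
    """Hypothetical stand-in for a metric: shared counters plus a private result cache."""

    def __init__(self, name, formula):
        self.name = name
        self.formula = formula
        self.state = {"tp": 0, "fp": 0, "fn": 0}
        self._computed = None  # cached result of the last compute()

    def update(self, tp, fp, fn):
        self.state["tp"] += tp
        self.state["fp"] += fp
        self.state["fn"] += fn
        self._computed = None  # new data invalidates this metric's own cache

    def compute(self):
        if self._computed is None:
            self._computed = self.formula(self.state)
        return self._computed


def collection_compute(group, propagate_cached_value):
    """Re-link a compute group, then compute every member (toy version)."""
    leader = group[0]
    for member in group[1:]:
        member.state = leader.state  # members share the leader's state by reference
        if propagate_cached_value:
            # Behaviour described in the issue: the leader's cached *result* is copied over
            member._computed = leader._computed
        elif leader._computed is None:
            # Hypothetical alternative: only propagate "the cache is stale", never the value
            member._computed = None
    return {m.name: m.compute() for m in group}


def make_group():
    precision = ToyMetric("precision", lambda s: s["tp"] / (s["tp"] + s["fp"]))
    recall = ToyMetric("recall", lambda s: s["tp"] / (s["tp"] + s["fn"]))
    precision.update(tp=3, fp=1, fn=2)  # only the first member receives updates
    return [precision, recall]


buggy = make_group()
print(collection_compute(buggy, propagate_cached_value=True))  # {'precision': 0.75, 'recall': 0.6}
print(collection_compute(buggy, propagate_cached_value=True))  # {'precision': 0.75, 'recall': 0.75}

fixed = make_group()
print(collection_compute(fixed, propagate_cached_value=False))  # {'precision': 0.75, 'recall': 0.6}
print(collection_compute(fixed, propagate_cached_value=False))  # {'precision': 0.75, 'recall': 0.6}
```

In the buggy variant the first call is correct, because the leader's cache is still empty when it is copied; the second call hands recall the precision value. The alternative keeps each member's own cached result and still resets it after the leader receives new data.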
But Tobi, why do you want to have two compute calls in the first place?
The reason I want to call `.compute()` twice separately is that, in my current code structure, I log my metrics in two different functions that both have access to the same metrics object. The use case is actually very close to the example I provided above. Of course, I could work around this by storing the result myself and passing it to the respective functions, but that would make my code less readable, so a fix for this issue would be nice to have. More importantly, other people could run into the same problem without noticing: for some metrics, e.g. scalar ones, the code runs as usual, just with wrong outputs.
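The workaround of storing the result yourself can be packaged so that both logging functions keep calling `.compute()` on the same object. A minimal sketch, assuming the wrapped object exposes `update`, `compute`, and `reset` (as a `MetricCollection` does) and that all updates go through the wrapper: the hypothetical `ComputeOnce` class runs the underlying `compute()` at most once per update cycle and hands the stored result to every caller. `CountingCollection`, `log_to_console`, and `log_to_file` are illustrative stand-ins, not part of any library.

```python
class ComputeOnce:
    """Hypothetical wrapper: calls the wrapped object's compute() at most once per update cycle."""

    def __init__(self, metrics):
        self.metrics = metrics  # assumed to expose update(), compute() and reset()
        self._result = None

    def update(self, *args, **kwargs):
        self._result = None  # new data makes any stored result stale
        self.metrics.update(*args, **kwargs)

    def reset(self):
        self._result = None
        self.metrics.reset()

    def compute(self):
        if self._result is None:
            self._result = self.metrics.compute()
        # Shallow copy: a caller adding or removing keys does not affect the other caller
        return dict(self._result)


class CountingCollection:
    """Toy stand-in for a metric collection that counts how often compute() runs."""

    def __init__(self):
        self.total = 0
        self.calls = 0

    def update(self, value):
        self.total += value

    def reset(self):
        self.total = 0

    def compute(self):
        self.calls += 1
        return {"sum": self.total}


def log_to_console(metrics):
    return metrics.compute()["sum"]  # first logging function


def log_to_file(metrics):
    return metrics.compute()["sum"]  # second logging function, same object


wrapped = ComputeOnce(CountingCollection())
wrapped.update(2)
wrapped.update(3)
print(log_to_console(wrapped), log_to_file(wrapped), wrapped.metrics.calls)  # 5 5 1
```

Because the underlying `compute()` runs only once per cycle, the second logging function never reads the stale per-metric cache. The trade-off is that updates made directly on the wrapped collection, bypassing the wrapper, are not noticed.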