After calling values, items, or __getitem__ with the argument copy = True, the references between metric states within a compute group are broken and real copies are created. The next call to update is supposed to re-establish these links, but with the current logic this never happens.
To Reproduce
import torch
from torchmetrics import PearsonCorrCoef, MetricCollection
# Initialize metric collection with two identical metrics
m1, m2 = PearsonCorrCoef(), PearsonCorrCoef()
# Add metrics to collection and enable compute groups
m12 = MetricCollection({'m1': m1, 'm2': m2}, compute_groups=True)
print(f'Generate random data')
x1, y1 = torch.randn(100), torch.randn(100)
print(f'Make first call to update')
m12.update(x1, y1)
print(f'We can see that a single compute group with both metrics was automatically created')
print(f'Compute groups: {m12.compute_groups}')
print("The collection knows that the states aren't copies of each other")
print(f'm12._state_is_copy: {m12._state_is_copy}')
print(f'And indeed they are stored at the same location')
print(f'States at same location? {getattr(m12, "m1").mean_x.data_ptr() == getattr(m12, "m2").mean_x.data_ptr()}')
print(f'By making a call to `items()`, the references between states are broken')
_ = m12.items()
print('The collection now knows that the states are copies, not shared references')
print(f'm12._state_is_copy: {m12._state_is_copy}')
print(f'We can see that indeed they are not stored at the same location')
print(f'States at same location? {getattr(m12, "m1").mean_x.data_ptr() == getattr(m12, "m2").mean_x.data_ptr()}')
print(f'Let\'s generate some more data...')
x2, y2 = torch.randn(100), torch.randn(100)
print(f'...and make a second call to update')
m12.update(x2, y2)
print(f'During update, the collection tries to restore the references between states...')
print(f'm12._state_is_copy: {m12._state_is_copy}')
print(f'...but it doesn\'t happen. We can see that the states are not actually the same')
print(f'States at same location? {getattr(m12, "m1").mean_x.data_ptr() == getattr(m12, "m2").mean_x.data_ptr()}')
print('Finally, if we generate one more batch of data...')
x3, y3 = torch.randn(100), torch.randn(100)
print(f'...and make a final call to update')
m12.update(x3, y3)
print(f'During update, the collection thinks it has references between states...')
print(f'm12._state_is_copy: {m12._state_is_copy}')
print(f'...but it doesn\'t so now the state values don\'t even match! Yikes!')
print(f'States have matching values? {m12._equal_metric_states(getattr(m12, "m1"), getattr(m12, "m2"))}')
Expected Behavior / Solution
The issue is caused by the logic in _compute_groups_create_state_ref, which only establishes references if _state_is_copy is False. However, as the referenced line in collections.py shows, update does not set _state_is_copy to False until after it has called _compute_groups_create_state_ref, so the re-linking step is skipped. The order of these two operations should be switched.
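The ordering problem can be illustrated without torchmetrics. The sketch below is a deliberately simplified toy model, not the library's source: ToyCollection, leader, follower, reset_flag_first, and run are hypothetical names invented for this example, and a plain Python list stands in for a tensor state. It only assumes the behavior described above: the linking helper is skipped while the copy flag is True, and update resets the flag either after the helper (current order) or before it (proposed order).

```python
from copy import deepcopy


class ToyCollection:
    """Toy stand-in for one compute group with two metrics (not torchmetrics code)."""

    def __init__(self, reset_flag_first: bool) -> None:
        # The leader owns the state and is the only member that gets updated.
        self.leader = {"total": [0.0]}
        # The follower initially holds a reference to the leader's state.
        self.follower = {"total": self.leader["total"]}
        self.state_is_copy = False
        # False models the reported order, True models the proposed fix.
        self.reset_flag_first = reset_flag_first

    def _create_state_ref(self, copy: bool = False) -> None:
        # Linking (or copying) only happens while the flag is False.
        if not self.state_is_copy:
            source = self.leader["total"]
            self.follower["total"] = deepcopy(source) if copy else source
        self.state_is_copy = copy

    def items(self):
        # Handing metrics to the user breaks the link by copying the state.
        self._create_state_ref(copy=True)
        return [("leader", self.leader), ("follower", self.follower)]

    def update(self, value: float) -> None:
        self.leader["total"][0] += value
        if self.state_is_copy:
            if self.reset_flag_first:
                # Proposed order: clear the flag, then re-link.
                self.state_is_copy = False
                self._create_state_ref()
            else:
                # Reported order: the helper sees the flag still set and skips linking.
                self._create_state_ref()
                self.state_is_copy = False

    def shares_state(self) -> bool:
        return self.follower["total"] is self.leader["total"]


def run(reset_flag_first: bool) -> ToyCollection:
    """Replay the update / items / update sequence from the reproduction."""
    collection = ToyCollection(reset_flag_first)
    collection.update(1.0)
    collection.items()
    collection.update(2.0)
    return collection


if __name__ == "__main__":
    for label, flag in (("reported order", False), ("proposed order", True)):
        c = run(flag)
        print(label, "| shared:", c.shares_state(), "| follower:", c.follower["total"])
```

In this model the reported order leaves the follower holding a stale copy while the flag claims the states are linked, which mirrors the mismatch seen in the reproduction; with the flag cleared first, the follower is re-linked to the leader's state.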
Environment
TorchMetrics version (if built from source, add commit SHA): 1.6.0
Python & PyTorch version: Python 3.10.12, PyTorch 2.6.0.dev20241201+cu124
Any other relevant information such as OS: Ubuntu
Referenced code: torchmetrics/src/torchmetrics/collections.py, line 251 in f7235c9