MetricCollection compute groups never re-established after copy #2896

Open
alexrgilbert opened this issue Jan 6, 2025 · 0 comments
Labels
bug / fix (Something isn't working) · help wanted (Extra attention is needed) · v1.6.x


🐛 Bug

The problematic call, made during MetricCollection.update:

self._compute_groups_create_state_ref()

After calling one of values, items, or __getitem__ with the argument copy=True, the references to metric states within compute groups are broken and actual copies are created. On the next call to update, these links are supposed to be re-established; however, because of the current logic, this never happens.

To Reproduce

import torch
from torchmetrics import PearsonCorrCoef, MetricCollection

# Initialize two identical metrics
m1, m2 = PearsonCorrCoef(), PearsonCorrCoef()
# Add them to a collection with compute groups enabled
m12 = MetricCollection({'m1': m1, 'm2': m2}, compute_groups=True)
print('Generate random data')
x1, y1 = torch.randn(100), torch.randn(100)
print('Make first call to update')
m12.update(x1, y1)
print('A single compute group with both metrics was automatically created')
print(f'Compute groups: {m12.compute_groups}')
print("The collection knows that the states aren't copies of each other")
print(f'm12._state_is_copy: {m12._state_is_copy}')
print('And indeed they are stored at the same location')
print(f'States at same location? {getattr(m12, "m1").mean_x.data_ptr() == getattr(m12, "m2").mean_x.data_ptr()}')
print('A call to items() breaks the references between states')
_ = m12.items()
print('The collection now knows that the states are not the same')
print(f'm12._state_is_copy: {m12._state_is_copy}')
print('And indeed they are no longer stored at the same location')
print(f'States at same location? {getattr(m12, "m1").mean_x.data_ptr() == getattr(m12, "m2").mean_x.data_ptr()}')
print("Let's generate some more data...")
x2, y2 = torch.randn(100), torch.randn(100)
print('...and make a second call to update')
m12.update(x2, y2)
print('During update, the collection tries to restore the references between states...')
print(f'm12._state_is_copy: {m12._state_is_copy}')
print("...but it doesn't happen. The states are still not the same")
print(f'States at same location? {getattr(m12, "m1").mean_x.data_ptr() == getattr(m12, "m2").mean_x.data_ptr()}')
print('Finally, if we generate one more batch of data...')
x3, y3 = torch.randn(100), torch.randn(100)
print('...and make a final call to update')
m12.update(x3, y3)
print('During update, the collection thinks it has references between states...')
print(f'm12._state_is_copy: {m12._state_is_copy}')
print("...but it doesn't, so now the state values don't even match! Yikes!")
print(f'States have matching values? {m12._equal_metric_states(getattr(m12, "m1"), getattr(m12, "m2"))}')

Expected Behavior / Solution

The issue is caused by logic in _compute_groups_create_state_ref, which only establishes references if _state_is_copy is False. However, as seen in the line referenced above, _state_is_copy isn't set back to False until after the call to _compute_groups_create_state_ref, so the guard always short-circuits and the references are never rebuilt. The order should be switched: clear the flag first, then re-create the references.
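The control flow can be demonstrated with a minimal toy model (a hypothetical class, not the real torchmetrics code; the method and attribute names merely mirror the internals described above):

```python
class Collection:
    """Toy model of the buggy vs. fixed ordering in update()."""

    def __init__(self):
        # items()/values() with copy=True leaves the states deep-copied
        self._state_is_copy = True
        self.refs_restored = False

    def _create_state_ref(self):
        # Mirrors _compute_groups_create_state_ref: it only builds the
        # shared references when the states are NOT marked as copies.
        if not self._state_is_copy:
            self.refs_restored = True

    def update_buggy(self):
        # Current order: the guard inside _create_state_ref still sees
        # _state_is_copy == True, so nothing is rebuilt.
        if self._state_is_copy:
            self._create_state_ref()
            self._state_is_copy = False

    def update_fixed(self):
        # Proposed order: clear the flag first, then rebuild references.
        if self._state_is_copy:
            self._state_is_copy = False
            self._create_state_ref()


buggy = Collection()
buggy.update_buggy()
print(buggy.refs_restored)   # False: references were never re-established

fixed = Collection()
fixed.update_fixed()
print(fixed.refs_restored)   # True: references restored as intended
```

With the buggy order the collection ends up believing the states are shared (_state_is_copy is False) while they are still independent copies, which is exactly the diverging-values behavior shown in the reproduction above.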

Environment

  • TorchMetrics version (if built from source, add commit SHA): 1.6.0
  • Python & PyTorch Version (e.g., 1.0): Python 3.10.12, PyTorch 2.6.0.dev20241201+cu124
  • Any other relevant information such as OS (e.g., Linux): Ubuntu
@Borda added the bug / fix (Something isn't working), v1.6.x, and help wanted (Extra attention is needed) labels on Jan 7, 2025