Pearson `_final_aggregation` modifies states in place (+ link out of date) #2893
Hi! Thanks for your contribution, great first issue!
One more follow-up issue: in my recent tests I encountered another bug(?). It occurs when the first two devices don't have any data for certain outputs, such that the sum of observations is zero. In this case, as you iterate over devices, the combined count is zero, the merge divides by zero, and the variance combination will become NaN. Below is a simple script to reproduce the issue, using the proposed new implementation:

```python
import torch
# Simulate Pearson metric on `N_DEVICES` devices
N_DEVICES = 4
# Simulate Pearson metric with `N_OUTPUTS` outputs
N_OUTPUTS = 5
def _final_aggregation(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance
    <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_
    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]

    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        nb = n1 + n2
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb
mean_x = torch.randn(N_DEVICES, N_OUTPUTS)
mean_y = torch.randn(N_DEVICES, N_OUTPUTS)
var_x = torch.randn(N_DEVICES, N_OUTPUTS)
var_y = torch.randn(N_DEVICES, N_OUTPUTS)
corr_xy = torch.randn(N_DEVICES, N_OUTPUTS)
n_total = torch.randint(1, 100, (N_DEVICES, N_OUTPUTS))
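# Zero out the first two devices to simulate devices that saw no data for any output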
for x in [mean_x, mean_y, var_x, var_y, corr_xy, n_total]:
    x[:2] = 0
# Current
mean_x_, mean_y_, var_x_, var_y_, corr_xy_, n_total_ = _final_aggregation(mean_x, mean_y, var_x, var_y, corr_xy, n_total)
# Expected
mean_x__, mean_y__, var_x__, var_y__, corr_xy__, n_total__ = _final_aggregation(mean_x[2:], mean_y[2:], var_x[2:], var_y[2:], corr_xy[2:], n_total[2:])
print(f'Aggregated states:')
print(f'mean_x: {mean_x_} (expected: {mean_x__})\n')
print(f'mean_y: {mean_y_} (expected: {mean_y__})\n')
print(f'var_x: {var_x_} (expected: {var_x__})\n')
print(f'var_y: {var_y_} (expected: {var_y__})\n')
print(f'corr_xy: {corr_xy_} (expected: {corr_xy__})\n')
print(f'n_total: {n_total_} (expected: {n_total__})\n')
```

A simple (maybe slightly tacky) way to fix this would be to fill in a small `eps` for the combined count whenever both device counts are zero:

```python
def _final_aggregation(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
    eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance
    <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_

    NOTE: We use `eps` to avoid division by zero when `n1` and `n2` are both zero. Generally,
    the value of `eps` should not matter, as if `n1` and `n2` are both zero, all the states will
    also be zero.
    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]

    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        # THIS LINE IS CHANGED!
        nb = torch.where(torch.logical_or(n1, n2), n1 + n2, eps)
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb
```
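If the patched definition above replaces the earlier one in the reproduction script (and the state tensors from that script are still in scope), a quick sanity check that aggregating over all devices now matches aggregating over only the non-empty ones could look like this:

```python
# Re-run the aggregation with the eps-patched function defined above.
out_all = _final_aggregation(mean_x, mean_y, var_x, var_y, corr_xy, n_total)
out_tail = _final_aggregation(mean_x[2:], mean_y[2:], var_x[2:], var_y[2:], corr_xy[2:], n_total[2:])

for name, a, b in zip(["mean_x", "mean_y", "var_x", "var_y", "corr_xy", "n_total"], out_all, out_tail):
    print(f"{name}: match = {torch.allclose(a.float(), b.float())}")
```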
🐛 Bug
torchmetrics/src/torchmetrics/regression/pearson.py, line 29 at commit 714494b
The current implementation of the `_final_aggregation` function used by `PearsonCorrCoef` was updated in #998. I believe this update introduced a bug where the states of the metric are modified in place, such that if the metric is initialized with `compute_with_cache=False` and used on multiple devices, subsequent calls to `compute` will return different (and inaccurate) results.
Additionally, I believe the reference link in the docstring (see here) was also not updated to reflect these changes, and is thus outdated.
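To see why in-place updates interact badly with `compute_with_cache=False`, here is a deliberately simplified toy example (not the torchmetrics code): once the reduction writes its result back into the gathered state, every subsequent compute starts from corrupted values.

```python
import torch

per_device_sums = torch.tensor([1.0, 2.0, 3.0])

def compute_in_place(state: torch.Tensor) -> torch.Tensor:
    state[0] = state.sum()  # mutates the caller's state tensor
    return state[0]

def compute_out_of_place(state: torch.Tensor) -> torch.Tensor:
    return state.sum()      # leaves the state untouched

print(compute_in_place(per_device_sums))   # tensor(6.)  -- correct the first time
print(compute_in_place(per_device_sums))   # tensor(11.) -- wrong on the second call
print(compute_out_of_place(torch.tensor([1.0, 2.0, 3.0])))  # tensor(6.) every time
```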
To Reproduce
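The original reproduction script was not carried over here, but a rough sketch of how one might check the shipped function for in-place modification (assuming `_final_aggregation` is importable from `torchmetrics.functional.regression.pearson`, and that the shapes below are representative) could look like this:

```python
import torch
from torchmetrics.functional.regression.pearson import _final_aggregation  # assumed import path

n_devices, n_outputs = 4, 5
# Fake per-device states: means_x, means_y, vars_x, vars_y, corrs_xy
states = [torch.randn(n_devices, n_outputs) for _ in range(5)]
nbs = torch.randint(1, 100, (n_devices, n_outputs)).float()
backups = [t.clone() for t in states + [nbs]]

first = _final_aggregation(*states, nbs)
second = _final_aggregation(*states, nbs)

# With an in-place implementation the inputs change after the first call and the
# two results disagree; a fixed implementation leaves both checks at True.
print("inputs unchanged:", all(torch.equal(t, b) for t, b in zip(states + [nbs], backups)))
print("repeated calls agree:", all(torch.allclose(a, b) for a, b in zip(first, second)))
```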
Expected Behavior / Solution
A quick description of how to implement a parallel algorithm for aggregating running statistics for calculating Pearson correlation is given on the Wikipedia page for variance calculation algorithms. More detailed derivations and analysis can be found in papers by Chan et al. and Schubert et al. (which are cited by the Wikipedia article). I haven't been able to find a new source supporting the current implementation online, and while the current implementation can indeed be simplified into the equations provided by these references, it is a bit more (overly?) complex (and difficult to understand).
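For reference, the pairwise combination rules from that page (with n the count, the bars denoting means, M2 the sum of squared deviations, and C the co-moment; these correspond to `nb`, `mean_x`/`mean_y`, `var_x`/`var_y`, and `corr_xy` in the code above) are:

$$
\begin{aligned}
n_{AB} &= n_A + n_B, \qquad \delta_x = \bar{x}_B - \bar{x}_A, \qquad \delta_y = \bar{y}_B - \bar{y}_A, \\
\bar{x}_{AB} &= \frac{n_A \bar{x}_A + n_B \bar{x}_B}{n_{AB}}, \\
M_{2,AB} &= M_{2,A} + M_{2,B} + \delta_x^2 \, \frac{n_A n_B}{n_{AB}}, \\
C_{AB} &= C_A + C_B + \delta_x \delta_y \, \frac{n_A n_B}{n_{AB}}.
\end{aligned}
$$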
Proposed Solution
Below is a simplified implementation which 1) fixes the in-place bug described above, and 2) is more closely aligned with the source algorithms. In tests on my data it matches the output of the current implementation, and it also passes the torchmetrics unit tests. If there is no reason for the current implementation which I am overlooking (e.g., numerical precision, avoiding overflow), would it be worthwhile to replace it with this simpler implementation (and also fix the bug)?
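For illustration, here is the building block the simplified implementation loops over, written for just two devices, together with the usual final step that turns the aggregated states into the correlation coefficient (a sketch with illustrative names, not the actual torchmetrics code):

```python
import torch

def merge_two(mx1, my1, vx1, vy1, cxy1, n1, mx2, my2, vx2, vy2, cxy2, n2):
    # Merge the running statistics of two devices, out of place.
    nb = n1 + n2
    factor = n1 * n2 / nb
    delta_x, delta_y = mx2 - mx1, my2 - my1
    mean_x = (n1 * mx1 + n2 * mx2) / nb
    mean_y = (n1 * my1 + n2 * my2) / nb
    var_x = vx1 + vx2 + factor * delta_x ** 2            # sum of squared deviations of x
    var_y = vy1 + vy2 + factor * delta_y ** 2            # sum of squared deviations of y
    corr_xy = cxy1 + cxy2 + factor * delta_x * delta_y   # co-moment
    return mean_x, mean_y, var_x, var_y, corr_xy, nb

def pearson_from_states(var_x, var_y, corr_xy):
    # The (n - 1) normalisation cancels between numerator and denominator.
    return corr_xy / (var_x * var_y).sqrt()
```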
Simple Verification
Below is a simple script verifying that the results of this implementation match those of the current algorithm. From this simple experiment, I find that the results match (up to 1e-8) when using double precision (i.e., `torch.float64`), but differ at float precision (i.e., `torch.float32`). However, upon further investigation, it seems the proposed implementation has better numeric stability, returning results at float precision which match the double-precision results up to ~5e-3 (with 64 devices), while the current implementation has maximum errors (relative to the double-precision results) on the order of 1e2.
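The original comparison script is not reproduced here; a rough sketch of that kind of check (assuming the proposed `_final_aggregation` from above is in scope) might look like this:

```python
import torch

# Aggregate random per-device states at float64 and float32 and report the
# maximum deviation of the float32 run from the double-precision result.
n_devices, n_outputs = 64, 5
states64 = [torch.randn(n_devices, n_outputs, dtype=torch.float64) * 10 for _ in range(5)]
nbs64 = torch.randint(1, 1000, (n_devices, n_outputs)).double()

ref = _final_aggregation(*states64, nbs64)                                 # float64 reference
out32 = _final_aggregation(*[s.float() for s in states64], nbs64.float())  # float32 run

for name, r, o in zip(["mean_x", "mean_y", "var_x", "var_y", "corr_xy", "n_total"], ref, out32):
    print(f"{name}: max abs error vs float64 = {(r - o.double()).abs().max().item():.3e}")
```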
Environment