Pearson _final_aggregation modifies states in place (+ link out of date) #2893

Open
alexrgilbert opened this issue Jan 3, 2025 · 2 comments
Labels: bug / fix (Something isn't working) · help wanted (Extra attention is needed) · v1.6.x

Comments


alexrgilbert commented Jan 3, 2025

🐛 Bug

The current implementation of the _final_aggregation function used by PearsonCorrCoef was updated in #998. I believe this update introduced a bug where the metric's states are modified in place, such that if the metric is initialized with compute_with_cache=False and used on multiple devices, subsequent calls to compute return different (and inaccurate) results.

Additionally, I believe the reference link in the docstring (see here) was not updated to reflect these changes and is therefore outdated.

To Reproduce

import torch
from torchmetrics.regression.pearson import _final_aggregation

# Simulate Pearson metric on `N_DEVICES` devices
N_DEVICES = 2
# Simulate Pearson metric with `N_OUTPUTS` outputs
N_OUTPUTS = 100
# Number of times to repeat the aggregation; more repeats lead to more drift
N_REPEATS = 2

mean_x = torch.randn(N_DEVICES, N_OUTPUTS)
mean_y = torch.randn(N_DEVICES, N_OUTPUTS)
var_x = torch.randn(N_DEVICES, N_OUTPUTS)
var_y = torch.randn(N_DEVICES, N_OUTPUTS)
corr_xy = torch.randn(N_DEVICES, N_OUTPUTS)
n_total = torch.randint(1, 100, (N_DEVICES, N_OUTPUTS))

_mean_x = mean_x.clone()
_mean_y = mean_y.clone()
_var_x = var_x.clone()
_var_y = var_y.clone()
_corr_xy = corr_xy.clone()
_n_total = n_total.clone()

for n in range(N_REPEATS):
    # Return values are intentionally ignored; we only check whether the inputs were mutated.
    _final_aggregation(_mean_x, _mean_y, _var_x, _var_y, _corr_xy, _n_total)

# Any printed "drift" means the corresponding state tensor was modified in place.
states = {
    "mean_x": (mean_x, _mean_x),
    "mean_y": (mean_y, _mean_y),
    "var_x": (var_x, _var_x),
    "var_y": (var_y, _var_y),
    "corr_xy": (corr_xy, _corr_xy),
    "n_total": (n_total, _n_total),
}
for name, (original, copy) in states.items():
    if not torch.equal(copy, original):
        diff = (copy - original).abs().float()  # cast so mean/std also work for the integer n_total state
        print(f"{name} drift: max={diff.max().item()}, mean={diff.mean().item()}, std={diff.std().item()}")

Expected Behavior / Solution

A quick description of a parallel algorithm for aggregating the running statistics needed for the Pearson correlation is given on the Wikipedia page for variance calculation algorithms. More detailed derivations and analysis can be found in the papers by Chan et al. and Schubert et al. (both cited by the Wikipedia article). I haven't been able to find a source supporting the current implementation, and while it can indeed be simplified into the equations given in these references, it is more complex than it needs to be and harder to follow.
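
For reference, the pairwise combination rules from these sources are roughly as follows, where $M_{2,x}$ and $M_{2,y}$ denote the unnormalized variance states (var_x, var_y), $C$ the co-moment state (corr_xy), and the subscripts $A$ and $B$ the two partitions being merged:

$$
\begin{aligned}
n_{AB} &= n_A + n_B, \qquad \delta_x = \bar{x}_B - \bar{x}_A, \qquad \delta_y = \bar{y}_B - \bar{y}_A, \\
\bar{x}_{AB} &= \frac{n_A \bar{x}_A + n_B \bar{x}_B}{n_{AB}}, \qquad
\bar{y}_{AB} = \frac{n_A \bar{y}_A + n_B \bar{y}_B}{n_{AB}}, \\
M_{2,x}^{AB} &= M_{2,x}^{A} + M_{2,x}^{B} + \delta_x^2 \, \frac{n_A n_B}{n_{AB}}, \qquad
M_{2,y}^{AB} = M_{2,y}^{A} + M_{2,y}^{B} + \delta_y^2 \, \frac{n_A n_B}{n_{AB}}, \\
C_{AB} &= C_A + C_B + \delta_x \, \delta_y \, \frac{n_A n_B}{n_{AB}}.
\end{aligned}
$$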

Proposed Solution

Below is a simplified implementation which (1) fixes the in-place bug described above, and (2) is more closely aligned with the source algorithms. In tests on my data it matches the output of the current implementation, and it also passes the torchmetrics unit tests. If there is no reason for the current implementation that I am overlooking (e.g., numerical precision, avoiding overflow), would it be worthwhile to replace it with this simpler implementation (and fix the bug in the process)?

def _final_aggregation(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_

    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        nb = n1 + n2
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb

Simple Verification

Below is a simple script verifying that the results of this implementation match those of the current algorithm. From this simple experiment, I find that the results match (up to 1e-8) when using double precision (torch.float64) but differ at single precision (torch.float32). However, upon further investigation, the proposed implementation appears to be more numerically stable: at single precision it matches the double-precision results up to ~5e-3 (with 64 devices), while the current implementation has maximum errors (relative to the double-precision result) on the order of 1e2.

import torch
from torchmetrics.regression.pearson import _final_aggregation

def _final_aggregation_FIX(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_

    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        nb = n1 + n2
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb

# Simulate Pearson metric on `N_DEVICES` devices
N_DEVICES = 2
# Simulate Pearson metric with `N_OUTPUTS` outputs
N_OUTPUTS = 100
# Number of times to repeat the verification
N_TESTS = 1000
# Precision of the tensors
DTYPE = torch.float64
# Accuracy tolerance
ATOL = 1e-6

# NOTE: At double precision, both implementations return matching results. At single precision, the results differ, with a
# max relative error of ~10 (pretty significant!). Upon further testing, it appears that the fixed implementation at
# single precision matches the double-precision result up to ~1e-4 (i.e. 1e-4 is the max ABSOLUTE error) for the variances,
# and better for the means/counts. The current implementation, however, is significantly worse, with a max relative error
# on the order of ~10, as noted above.

for n in range(N_TESTS):

    mean_x = torch.randn(N_DEVICES, N_OUTPUTS, dtype=DTYPE)
    mean_y = torch.randn(N_DEVICES, N_OUTPUTS, dtype=DTYPE)
    var_x = torch.randn(N_DEVICES, N_OUTPUTS, dtype=DTYPE)
    var_y = torch.randn(N_DEVICES, N_OUTPUTS, dtype=DTYPE)
    corr_xy = torch.randn(N_DEVICES, N_OUTPUTS, dtype=DTYPE)
    n_total = torch.randint(1, 100, (N_DEVICES, N_OUTPUTS), dtype=DTYPE)

    mean_x_fix, mean_y_fix, var_x_fix, var_y_fix, corr_xy_fix, n_total_fix = _final_aggregation_FIX(mean_x, mean_y, var_x, var_y, corr_xy, n_total)
    mean_x_, mean_y_, var_x_, var_y_, corr_xy_, n_total_ = _final_aggregation(mean_x, mean_y, var_x, var_y, corr_xy, n_total)

    assert torch.allclose(mean_x_fix, mean_x_, atol=ATOL, rtol=0)
    assert torch.allclose(mean_y_fix, mean_y_, atol=ATOL, rtol=0)
    assert torch.allclose(var_x_fix, var_x_, atol=ATOL, rtol=0)
    assert torch.allclose(var_y_fix, var_y_, atol=ATOL, rtol=0)
    assert torch.allclose(corr_xy_fix, corr_xy_, atol=ATOL, rtol=0)
    assert torch.allclose(n_total_fix, n_total_, atol=ATOL, rtol=0)

print("All tests passed!")
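
For completeness, here is a rough sketch (not part of the verification above) of how the single- versus double-precision comparison described in the text could be run: compute a float64 reference with the proposed implementation and measure how far each implementation drifts from it at float32. It assumes _final_aggregation_FIX and the imports from the script above are in scope; the names ref64, fix32, and cur32 are just illustrative.

# Sketch: compare both implementations at float32 against a float64 reference.
N_DEVICES = 64
N_OUTPUTS = 100

stats64 = [torch.randn(N_DEVICES, N_OUTPUTS, dtype=torch.float64) for _ in range(5)]
stats64.append(torch.randint(1, 100, (N_DEVICES, N_OUTPUTS)).to(torch.float64))
stats32 = [s.to(torch.float32) for s in stats64]

ref64 = _final_aggregation_FIX(*stats64)
fix32 = _final_aggregation_FIX(*[s.clone() for s in stats32])
# Clone before calling the current implementation, since it may modify its inputs in place.
cur32 = _final_aggregation(*[s.clone() for s in stats32])

names = ("mean_x", "mean_y", "var_x", "var_y", "corr_xy", "n_total")
for name, ref, fix, cur in zip(names, ref64, fix32, cur32):
    err_fix = (fix.double() - ref).abs().max().item()
    err_cur = (cur.double() - ref).abs().max().item()
    print(f"{name}: max abs error vs float64 -> proposed={err_fix:.3e}, current={err_cur:.3e}")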

Environment

  • TorchMetrics version (if built from source, add commit SHA): 1.6.0
  • Python & PyTorch Version (e.g., 1.0): Python 3.10.12, PyTorch 2.6.0.dev20241201+cu124
  • Any other relevant information such as OS (e.g., Linux): Ubuntu

github-actions bot commented Jan 3, 2025

Hi! thanks for your contribution!, great first issue!

@alexrgilbert alexrgilbert changed the title Pearson _final_aggregation modifies states in place (and link out of date) Pearson _final_aggregation modifies states in place (+ link out of date) Jan 3, 2025

alexrgilbert (Author) commented Jan 6, 2025

One more follow-up issue: in my recent tests I encountered another (possible) bug. It occurs when the first two devices have no data for certain outputs, so that the sum of observations is zero. In that case, as you iterate over devices, the variance combination becomes NaN once the first two devices have been combined, when those entries should really remain empty.

Below is a simple script reproducing the issue, using the proposed new _final_aggregation (the issue also exists with the current implementation):

import torch

# Simulate Pearson metric on `N_DEVICES` devices
N_DEVICES = 4
# Simulate Pearson metric with `N_OUTPUTS` outputs
N_OUTPUTS = 5

def _final_aggregation(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_

    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        nb = n1 + n2
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb

mean_x = torch.randn(N_DEVICES, N_OUTPUTS)
mean_y = torch.randn(N_DEVICES, N_OUTPUTS)
var_x = torch.randn(N_DEVICES, N_OUTPUTS)
var_y = torch.randn(N_DEVICES, N_OUTPUTS)
corr_xy = torch.randn(N_DEVICES, N_OUTPUTS)
n_total = torch.randint(1, 100, (N_DEVICES, N_OUTPUTS))

for x in [mean_x, mean_y, var_x, var_y, corr_xy, n_total]:
    x[:2] = 0

# Aggregation over all devices (including the two with no observations)
mean_x_, mean_y_, var_x_, var_y_, corr_xy_, n_total_ = _final_aggregation(mean_x, mean_y, var_x, var_y, corr_xy, n_total)
# Expected result: aggregation over only the non-empty devices
mean_x__, mean_y__, var_x__, var_y__, corr_xy__, n_total__ = _final_aggregation(mean_x[2:], mean_y[2:], var_x[2:], var_y[2:], corr_xy[2:], n_total[2:])

print('Aggregated states:')
print(f'mean_x: {mean_x_} (expected: {mean_x__})\n')
print(f'mean_y: {mean_y_} (expected: {mean_y__})\n')
print(f'var_x: {var_x_} (expected: {var_x__})\n')
print(f'var_y: {var_y_} (expected: {var_y__})\n')
print(f'corr_xy: {corr_xy_} (expected: {corr_xy__})\n')
print(f'n_total: {n_total_} (expected: {n_total__})\n')

A simple (if slightly hacky) way to fix this would be to fill in nb (i.e., the sum of observed counts across devices) with some small epsilon at locations where it is zero. I believe this should not impact the result at all, since if the observation counts are zero, all other states should also be zero, but I may be overlooking something. Here is the modified proposed implementation:

def _final_aggregation(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
    eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_

    NOTE: We use `eps` to avoid division by zero when `n1` and `n2` are both zero. Generally,
     the value of `eps` should not matter, as if `n1` and `n2` are both zero, all the states will
     also be zero.
    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        # THIS LINE IS CHANGED: pad zero counts with eps to avoid division by zero
        nb = torch.where(torch.logical_or(n1, n2), n1 + n2, eps)
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb
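
As a quick sanity check (a sketch only, assuming the tensors from the reproduction script above are still in scope and that _final_aggregation now refers to the eps variant), the result can be compared against aggregating only the non-empty devices:

# Re-run the zero-count example with the eps-padded version defined above.
out_eps = _final_aggregation(mean_x, mean_y, var_x, var_y, corr_xy, n_total)
out_expected = _final_aggregation(mean_x[2:], mean_y[2:], var_x[2:], var_y[2:], corr_xy[2:], n_total[2:])

for name, got, expected in zip(("mean_x", "mean_y", "var_x", "var_y", "corr_xy", "n_total"), out_eps, out_expected):
    assert not torch.isnan(got.double()).any(), f"{name} still contains NaNs"
    # The eps padding perturbs the result on the order of eps, hence the loose tolerance.
    assert torch.allclose(got.double(), expected.double(), atol=1e-6), f"{name} does not match aggregation over the non-empty devices"
print("eps fix: no NaNs, and results match aggregation over the non-empty devices")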

@Borda Borda added bug / fix Something isn't working v1.6.x help wanted Extra attention is needed labels Jan 7, 2025