diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81db60a16ba..393bdf56fd3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -50,7 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Fixed bug in `PearsonCorrCoef` when updated with a single sample at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019))
 
 
 ## [1.1.2] - 2023-09-11
diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py
index 3c3d4dfc2c9..eccbc0c9903 100644
--- a/src/torchmetrics/functional/regression/pearson.py
+++ b/src/torchmetrics/functional/regression/pearson.py
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
 
-from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
+from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs, _check_data_shape_to_weights
 from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.checks import _check_same_shape
 
@@ -32,6 +32,7 @@ def _pearson_corrcoef_update(
     corr_xy: Tensor,
     n_prior: Tensor,
     num_outputs: int,
+    weights: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Update and returns variables required to compute Pearson Correlation Coefficient.
 
@@ -47,34 +48,57 @@
         corr_xy: current covariance estimate between x and y tensor
         n_prior: current number of observed observations
         num_outputs: Number of outputs in multioutput setting
+        weights: optional sample weights associated with the scores
 
     """
     # Data checking
     _check_same_shape(preds, target)
     _check_data_shape_to_num_outputs(preds, target, num_outputs)
-    n_obs = preds.shape[0]
+    if weights is not None:
+        _check_data_shape_to_weights(preds, weights)
+
+    n_obs = preds.shape[0] if weights is None else weights.sum()
     cond = n_prior.mean() > 0 or n_obs == 1
 
     if cond:
-        mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)
-        my_new = (n_prior * mean_y + target.sum(0)) / (n_prior + n_obs)
+        if weights is None:
+            mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)
+            my_new = (n_prior * mean_y + target.sum(0)) / (n_prior + n_obs)
+        else:
+            mx_new = (n_prior * mean_x + torch.matmul(weights, preds)) / (n_prior + n_obs)
+            my_new = (n_prior * mean_y + torch.matmul(weights, target)) / (n_prior + n_obs)
     else:
-        mx_new = preds.mean(0)
-        my_new = target.mean(0)
+        if weights is None:
+            mx_new = preds.mean(0)
+            my_new = target.mean(0)
+        else:
+            mx_new = torch.matmul(weights, preds) / weights.sum()
+            my_new = torch.matmul(weights, target) / weights.sum()
 
     n_prior += n_obs
 
+    # Calculate variances
     if cond:
-        var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
-        var_y += ((target - my_new) * (target - mean_y)).sum(0)
+        if weights is None:
+            var_x += ((preds - mx_new) * (preds - mean_x)).sum(0)
+            var_y += ((target - my_new) * (target - mean_y)).sum(0)
+        else:
+            var_x += torch.matmul(weights, (preds - mx_new) * (preds - mean_x))
+            var_y += torch.matmul(weights, (target - my_new) * (target - mean_y))
+    else:
+        if weights is None:
+            var_x += preds.var(0) * (n_obs - 1)
+            var_y += target.var(0) * (n_obs - 1)
+        else:
+            var_x += torch.matmul(weights, (preds - mx_new) ** 2)
+            var_y += torch.matmul(weights, (target - my_new) ** 2)
+
+    if weights is None:
+        corr_xy += ((preds - mx_new) * (target - mean_y)).sum(0)
     else:
-        var_x += preds.var(0) * (n_obs - 1)
-        var_y += target.var(0) * (n_obs - 1)
-    corr_xy += ((preds - mx_new) * (target - mean_y)).sum(0)
-    mean_x = mx_new
-    mean_y = my_new
+        corr_xy += torch.matmul(weights, (preds - mx_new) * (target - mean_y))
 
-    return mean_x, mean_y, var_x, var_y, corr_xy, n_prior
+    return mx_new, my_new, var_x, var_y, corr_xy, n_prior
 
 
 def _pearson_corrcoef_compute(
@@ -92,9 +116,6 @@
         nb: number of observations
 
     """
-    var_x /= nb - 1
-    var_y /= nb - 1
-    corr_xy /= nb - 1
     # if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
     # on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
     if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
@@ -114,12 +135,16 @@
     return torch.clamp(corrcoef, -1.0, 1.0)
 
 
-def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
+def pearson_corrcoef(preds: Tensor, target: Tensor, weights: Optional[Tensor] = None) -> Tensor:
     """Compute pearson correlation coefficient.
 
     Args:
-        preds: estimated scores
-        target: ground truth scores
+        preds: torch.Tensor of shape (n_samples,) or (n_samples, n_outputs)
+            Estimated scores
+        target: torch.Tensor of shape (n_samples,) or (n_samples, n_outputs)
+            Ground truth scores
+        weights: torch.Tensor of shape (n_samples,), default=None
+            Sample weights
 
     Example (single output regression):
         >>> from torchmetrics.functional.regression import pearson_corrcoef
@@ -128,6 +153,14 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
         >>> pearson_corrcoef(preds, target)
         tensor(0.9849)
 
+    Example (weighted single output regression):
+        >>> from torchmetrics.functional.regression import pearson_corrcoef
+        >>> target = torch.tensor([3, -0.5, 2, 7])
+        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
+        >>> weights = torch.tensor([2.5, 0.0, 2, 8])
+        >>> pearson_corrcoef(preds, target, weights)
+        tensor(0.9964)
+
     Example (multi output regression):
         >>> from torchmetrics.functional.regression import pearson_corrcoef
         >>> target = torch.tensor([[3, -0.5], [2, 7]])
@@ -135,12 +168,29 @@ def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
         >>> pearson_corrcoef(preds, target)
         tensor([1., 1.])
 
+    Example (weighted multi output regression):
+        >>> from torchmetrics.functional.regression import pearson_corrcoef
+        >>> target = torch.tensor([[3, -0.5], [2, 7]])
+        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
+        >>> weights = torch.tensor([1.0, 3.0])
+        >>> pearson_corrcoef(preds, target, weights)
+        tensor([1., 1.])
+
     """
     d = preds.shape[1] if preds.ndim == 2 else 1
     _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device)
     mean_x, mean_y, var_x = _temp.clone(), _temp.clone(), _temp.clone()
     var_y, corr_xy, nb = _temp.clone(), _temp.clone(), _temp.clone()
     _, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(
-        preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb, num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
+        preds,
+        target,
+        mean_x,
+        mean_y,
+        var_x,
+        var_y,
+        corr_xy,
+        nb,
+        num_outputs=1 if preds.ndim == 1 else preds.shape[-1],
+        weights=weights,
     )
     return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)
diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py
index 59612927f26..d1c59a03cfd 100644
--- a/src/torchmetrics/functional/regression/utils.py
+++ b/src/torchmetrics/functional/regression/utils.py
@@ -41,3 +41,25 @@ def _check_data_shape_to_num_outputs(
             f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
             f" and {preds.shape[1]}."
         )
+
+
+def _check_data_shape_to_weights(preds: Tensor, weights: Tensor) -> None:
+    """Check that the predictions and weights have the correct shape, else raise error.
+
+    This check assumes that the prediction and target tensors have already been confirmed to have the same shape.
+    It further assumes that `preds` is either a 1- or 2-dimensional tensor.
+
+    """
+    if preds.ndim == 1 and preds.shape != weights.shape:
+        raise ValueError(
+            f"Expected `preds.shape` to equal `weights.shape`, but got {preds.shape} and {weights.shape}."
+        )
+    elif preds.ndim == 2:
+        if weights.ndim == 1 and preds.shape[0] != len(weights):
+            raise ValueError(
+                f"Expected `preds.shape[0]` to equal `len(weights)`, but got {preds.shape[0]} and {len(weights)}."
+            )
+        if weights.ndim == 2 and preds.shape != weights.shape:
+            raise ValueError(
+                f"Expected `preds.shape` to equal `weights.shape`, but got {preds.shape} and {weights.shape}."
+            )
diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py
index a25fa72ff7e..5506ed304b5 100644
--- a/src/torchmetrics/regression/pearson.py
+++ b/src/torchmetrics/regression/pearson.py
@@ -141,7 +141,7 @@ def __init__(
         self.add_state("corr_xy", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
         self.add_state("n_total", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
 
-    def update(self, preds: Tensor, target: Tensor) -> None:
+    def update(self, preds: Tensor, target: Tensor, weights: Optional[Tensor] = None) -> None:
         """Update state with predictions and targets."""
         self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total = _pearson_corrcoef_update(
             preds,
@@ -153,6 +153,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
             self.corr_xy,
             self.n_total,
             self.num_outputs,
+            weights=weights,
         )
 
     def compute(self) -> Tensor:
diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py
index 3740e7bf335..8242e0aa3c3 100644
--- a/tests/unittests/helpers/testers.py
+++ b/tests/unittests/helpers/testers.py
@@ -270,7 +270,7 @@ def _functional_test(
         extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}
         tm_result = metric(preds[i], target[i], **extra_kwargs)
         extra_kwargs = {
-            k: v.cpu() if isinstance(v, Tensor) else v
+            k: (v[i] if not fragment_kwargs else v).cpu() if isinstance(v, Tensor) else v
             for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()
         }
         ref_result = reference_metric(
diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py
index 90c7df76b92..62ffa2a16cc 100644
--- a/tests/unittests/regression/test_pearson.py
+++ b/tests/unittests/regression/test_pearson.py
@@ -14,9 +14,11 @@
 from collections import namedtuple
 from functools import partial
 
+import numpy as np
 import pytest
 import torch
 from scipy.stats import pearsonr
+from torch import Tensor
 
 from torchmetrics.functional.regression.pearson import pearson_corrcoef
 from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation
@@ -28,6 +30,7 @@
 Input = namedtuple("Input", ["preds", "target"])
 
+
 _single_target_inputs1 = Input(
     preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
     target=torch.rand(NUM_BATCHES, BATCH_SIZE),
 )
@@ -50,12 +53,34 @@
 )
 
 
+def _ref_metric(preds, target, weights=None):
+    if weights is None:
+        return _scipy_pearson(preds, target)
+    return _weighted_pearson(preds, target, weights)
+
+
 def _scipy_pearson(preds, target):
     if preds.ndim == 2:
         return [pearsonr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)]
     return pearsonr(target.numpy(), preds.numpy())[0]
 
 
+def _weighted_pearson(preds, target, weights):
+    preds = preds.numpy() if isinstance(preds, Tensor) else preds
+    target = target.numpy() if isinstance(target, Tensor) else target
+    weights = weights.numpy() if isinstance(weights, Tensor) else weights
+
+    if preds.ndim == 2:
+        return [_weighted_pearson(p, t, weights) for p, t in zip(preds.T, target.T)]
+
+    mx = (weights * preds).sum() / weights.sum()
+    my = (weights * target).sum() / weights.sum()
+    var_x = (weights * (preds - mx) ** 2).sum()
+    var_y = (weights * (target - my) ** 2).sum()
+    cov_xy = (weights * (preds - mx) * (target - my)).sum()
+    return cov_xy / np.sqrt(var_x * var_y)
+
+
 @pytest.mark.parametrize(
     "preds, target",
     [
@@ -70,9 +95,16 @@ class TestPearsonCorrCoef(MetricTester):
 
     atol = 1e-3
 
+    @pytest.mark.parametrize(
+        "kwargs_update",
+        [
+            pytest.param({}, id="None weights"),
+            pytest.param({"weights": torch.rand(NUM_BATCHES, BATCH_SIZE)}, id="tensor weights"),
+        ],
+    )
     @pytest.mark.parametrize("compute_on_cpu", [True, False])
     @pytest.mark.parametrize("ddp", [True, False])
-    def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp):
+    def test_pearson_corrcoef(self, preds, target, kwargs_update, compute_on_cpu, ddp):
         """Test class implementation of metric."""
         num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
         self.run_class_metric_test(
@@ -80,14 +112,26 @@ def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp):
             preds=preds,
             target=target,
             metric_class=PearsonCorrCoef,
-            reference_metric=_scipy_pearson,
+            reference_metric=_ref_metric,
             metric_args={"num_outputs": num_outputs, "compute_on_cpu": compute_on_cpu},
+            weights=kwargs_update.get("weights", None),
         )
 
-    def test_pearson_corrcoef_functional(self, preds, target):
+    @pytest.mark.parametrize(
+        "kwargs_update",
+        [
+            pytest.param({}, id="None weights"),
+            pytest.param({"weights": torch.rand(NUM_BATCHES, BATCH_SIZE)}, id="tensor weights"),
+        ],
+    )
+    def test_pearson_corrcoef_functional(self, preds, target, kwargs_update):
         """Test functional implementation of metric."""
         self.run_functional_metric_test(
-            preds=preds, target=target, metric_functional=pearson_corrcoef, reference_metric=_scipy_pearson
+            preds=preds,
+            target=target,
+            metric_functional=pearson_corrcoef,
+            reference_metric=_ref_metric,
+            weights=kwargs_update.get("weights", None),
         )
 
     def test_pearson_corrcoef_differentiability(self, preds, target):
@@ -100,7 +144,7 @@ def test_pearson_corrcoef_differentiability(self, preds, target):
             metric_functional=pearson_corrcoef,
         )
 
-    def test_pearson_corrcoef_half_cpu(self, preds, target):
+    def test_pearson_corrcoef_half_cpu(self, preds, target, metric_args):
         """Test dtype support of the metric on CPU."""
         num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
         self.run_precision_test_cpu(preds, target, partial(PearsonCorrCoef, num_outputs=num_outputs), pearson_corrcoef)
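The weighted path introduced above can be exercised end to end with the minimal sketch below. It assumes the `weights` argument lands exactly as in this diff, i.e. `pearson_corrcoef(preds, target, weights)` and `PearsonCorrCoef.update(preds, target, weights)`; the helper `weighted_pearson_reference` and the sample values are illustrative only and are not part of the PR.

```python
# Minimal sanity check for the weighted Pearson path added in this diff.
# Assumes a torchmetrics build that includes the `weights` argument shown above;
# `weighted_pearson_reference` is a hypothetical helper, not part of the PR.
import torch

from torchmetrics.functional.regression import pearson_corrcoef
from torchmetrics.regression import PearsonCorrCoef


def weighted_pearson_reference(preds, target, weights):
    """Closed-form weighted Pearson correlation for 1-dimensional inputs."""
    mx = (weights * preds).sum() / weights.sum()
    my = (weights * target).sum() / weights.sum()
    cov_xy = (weights * (preds - mx) * (target - my)).sum()
    var_x = (weights * (preds - mx) ** 2).sum()
    var_y = (weights * (target - my) ** 2).sum()
    return cov_xy / (var_x * var_y).sqrt()


preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
target = torch.tensor([3.0, -0.5, 2.0, 7.0])
weights = torch.tensor([1.0, 0.5, 2.0, 1.5])

# Functional API: one-shot computation with sample weights.
one_shot = pearson_corrcoef(preds, target, weights)

# Module API: weights are passed per update call, here split over two batches.
metric = PearsonCorrCoef()
metric.update(preds[:2], target[:2], weights[:2])
metric.update(preds[2:], target[2:], weights[2:])
accumulated = metric.compute()

reference = weighted_pearson_reference(preds, target, weights)
print(one_shot, accumulated, reference)  # all three should agree up to float error
```

Splitting the data across two `update` calls is deliberate: it exercises the branch of `_pearson_corrcoef_update` where running statistics are merged with a new weighted batch, which is where the variance and covariance bookkeeping is easiest to get wrong.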