diff --git a/CHANGELOG.md b/CHANGELOG.md index 679c3c93ef2..2b4feca7fdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378)) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index d0e9d15c6dc..62fd197d1d7 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -154,8 +154,9 @@ def _ssim_update( mu_target_sq = output_list[1].pow(2) mu_pred_target = output_list[0] * output_list[1] - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq + # Calculate the variance of the predicted and target images, should be non-negative + sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0) + sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0) sigma_pred_target = output_list[4] - mu_pred_target upper = 2 * sigma_pred_target.to(dtype) + c2 diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index c52ebe8c16b..ce37d4b24c9 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -102,8 +102,9 @@ def _uqi_compute( mu_target_sq = output_list[1].pow(2) mu_pred_target = output_list[0] * output_list[1] - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq + # Calculate the variance of the predicted and target images, should be non-negative + sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0) + sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0) sigma_pred_target = output_list[4] - mu_pred_target upper = 2 * sigma_pred_target