From 37f0219ba11e5941e658ae5378e4ea4cf46f33f0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 14 Feb 2024 07:37:41 +0100 Subject: [PATCH] Clamp variance calculation in certain image metrics (#2378) (cherry picked from commit afae59e4d75ef2dfd0023e3a665e46a0a62833a5) --- CHANGELOG.md | 2 +- src/torchmetrics/functional/image/ssim.py | 5 +++-- src/torchmetrics/functional/image/uqi.py | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e263da4f50..c1a8b5e05b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378)) ## [1.3.1] - 2024-02-12 diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index d0e9d15c6dc..62fd197d1d7 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -154,8 +154,9 @@ def _ssim_update( mu_target_sq = output_list[1].pow(2) mu_pred_target = output_list[0] * output_list[1] - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq + # Calculate the variance of the predicted and target images, should be non-negative + sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0) + sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0) sigma_pred_target = output_list[4] - mu_pred_target upper = 2 * sigma_pred_target.to(dtype) + c2 diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index c52ebe8c16b..ce37d4b24c9 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -102,8 +102,9 @@ def _uqi_compute( mu_target_sq = output_list[1].pow(2) mu_pred_target = output_list[0] * output_list[1] - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq + # Calculate the variance of the predicted and target images, should be non-negative + sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0) + sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0) sigma_pred_target = output_list[4] - mu_pred_target upper = 2 * sigma_pred_target