diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index e43a91f2466..77c23a9872c 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -160,8 +160,8 @@ def _ssim_update( sigma_target_sq = output_list[3] - mu_target_sq sigma_pred_target = output_list[4] - mu_pred_target - upper = 2 * sigma_pred_target + c2 - lower = sigma_pred_sq + sigma_target_sq + c2 + upper = 2 * sigma_pred_target.to(dtype) + c2 + lower = (sigma_pred_sq + sigma_target_sq).to(dtype) + c2 ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)