Skip to content

Commit

Permalink
Fix SSIM memory (#539)
Browse files Browse the repository at this point in the history
* fix

* changelog
  • Loading branch information
SkafteNicki authored Sep 21, 2021
1 parent 68e5636 commit e93bae1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `F1` with `average='macro'` and `ignore_index!=None` ([#495](https://github.com/PyTorchLightning/metrics/pull/495))


- Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539))


## [0.5.1] - 2021-08-30

### Added
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _ssim_compute(

input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
output_list = [outputs[x * preds.size(0) : (x + 1) * preds.size(0)] for x in range(len(outputs))]
output_list = outputs.split(preds.shape[0])

mu_pred_sq = output_list[0].pow(2)
mu_target_sq = output_list[1].pow(2)
Expand Down

0 comments on commit e93bae1

Please sign in to comment.