Memory optimization of perplexity metric (#2346)
* reduce memory footprint when computing perplexity

(cherry picked from commit c0d2d3a)
nsmlzl authored and Borda committed Feb 12, 2024
1 parent b945fd3 commit ac1b139
Showing 2 changed files with 4 additions and 1 deletion.
CHANGELOG.md (3 additions, 0 deletions)

```diff
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed high memory consumption in `Perplexity` metric ([#2346](https://github.com/Lightning-AI/torchmetrics/pull/2346))
+
+
 - Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))
```
src/torchmetrics/functional/text/perplexity.py (1 addition, 1 deletion)

```diff
@@ -91,7 +91,7 @@ def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int
     else:
         mask = torch.ones_like(target, dtype=torch.bool)
 
-    probs = probs[:, target].diagonal()[mask]
+    probs = probs[torch.arange(target.numel()), target][mask]
     total_log_probs = -probs.log().sum()
     count = mask.sum()
```
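The memory saving comes from how the target probabilities are gathered. The old expression `probs[:, target]` materializes an (N, N) intermediate tensor (N = number of token positions) just to read its diagonal, while the new advanced-indexing form picks one entry per row and allocates only O(N). A minimal sketch of the equivalence, written with NumPy (whose advanced indexing behaves like PyTorch's for this pattern); the sizes and values below are made up for illustration:

```python
import numpy as np

# Toy setup: N token positions, V vocabulary classes (hypothetical values).
N, V = 4, 6
rng = np.random.default_rng(0)
probs = rng.random((N, V))
probs /= probs.sum(axis=1, keepdims=True)   # each row is a distribution over V classes
target = np.array([2, 0, 5, 1])             # target class index per position
mask = np.array([True, True, False, True])  # e.g. positions not equal to ignore_index

# Old approach: probs[:, target] builds an (N, N) matrix whose diagonal
# holds the wanted entries -- O(N^2) temporary memory.
old = probs[:, target].diagonal()[mask]

# New approach: paired row/column indices gather one entry per row -- O(N).
new = probs[np.arange(target.size), target][mask]

assert np.allclose(old, new)
```

Both expressions select `probs[i, target[i]]` for each unmasked position `i`; only the size of the intermediate differs, which is why the change is a pure memory optimization with no effect on the metric's value.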
