diff --git a/CHANGELOG.md b/CHANGELOG.md index ad35113a0f6..b7d3e419e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed high memory consumption in `Perplexity` metric ([#2346](https://github.com/Lightning-AI/torchmetrics/pull/2346)) + + - Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348)) diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index cb0bafd5082..39f832905cf 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -91,7 +91,7 @@ def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int else: mask = torch.ones_like(target, dtype=torch.bool) - probs = probs[:, target].diagonal()[mask] + probs = probs[torch.arange(target.numel()), target][mask] total_log_probs = -probs.log().sum() count = mask.sum()