From c0d2d3ab7221f8804ba55baf3945637d68485cc1 Mon Sep 17 00:00:00 2001 From: niklas Date: Thu, 8 Feb 2024 03:35:50 +0100 Subject: [PATCH] Memory optimization of perplexity metric (#2346) * reduce memory footprint when computing perplexity --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/text/perplexity.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad35113a0f6..b7d3e419e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed high memory consumption in `Perplexity` metric ([#2346](https://github.com/Lightning-AI/torchmetrics/pull/2346)) + + - Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348)) diff --git a/src/torchmetrics/functional/text/perplexity.py b/src/torchmetrics/functional/text/perplexity.py index cb0bafd5082..39f832905cf 100644 --- a/src/torchmetrics/functional/text/perplexity.py +++ b/src/torchmetrics/functional/text/perplexity.py @@ -91,7 +91,7 @@ def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int else: mask = torch.ones_like(target, dtype=torch.bool) - probs = probs[:, target].diagonal()[mask] + probs = probs[torch.arange(target.numel()), target][mask] total_log_probs = -probs.log().sum() count = mask.sum()