High memory usage of Perplexity metric #2337

Closed
nsmlzl opened this issue Jan 30, 2024 · 2 comments · Fixed by #2346
Labels: bug / fix, help wanted, v1.2.x

Comments

nsmlzl (Contributor) commented Jan 30, 2024

🐛 Bug

I ran out of GPU memory when computing the Perplexity metric and would like to propose a small optimization to reduce its memory usage.

To Reproduce

For instance, when running the following code, PyTorch tries to allocate 1024 GB of GPU memory on my system.

from torchmetrics.text import Perplexity
import torch

gen = torch.manual_seed(42)
# preds: (batch=512, seq_len=1024, vocab=12); target: (batch=512, seq_len=1024)
preds = torch.rand(512, 1024, 12, generator=gen).cuda()
target = torch.randint(12, (512, 1024), generator=gen).cuda()

perp = Perplexity().cuda()
print(perp(preds, target))
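
For reference, the 1024 GB figure is consistent with a float32 temporary tensor of shape (512·1024, 512·1024), which is the intermediate discussed below (a quick back-of-the-envelope check):

n = 512 * 1024                # number of flattened (batch * seq_len) positions
print(n * n * 4 / 2**30)      # bytes in a float32 (n, n) tensor -> 1024.0 GiB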

Memory Inefficiency

I think the inefficiency is in this line:

probs = probs[:, target].diagonal()[mask]

probs[:, target] creates a large temporary tensor with (512*1024)^2 elements, even though only its diagonal values are used afterwards.
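
At a toy scale (just a sketch with made-up sizes, not the metric's actual internals), the same indexing pattern produces an (N, N) temporary even though only N values are kept:

import torch

N, vocab = 8, 12                 # small stand-ins for flattened positions and vocab size
probs = torch.rand(N, vocab)
target = torch.randint(vocab, (N,))

tmp = probs[:, target]           # shape (N, N) -- grows quadratically with N
picked = tmp.diagonal()          # shape (N,) -- only these values are actually used
print(tmp.shape, picked.shape)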

Potential Solution

In contrast,

probs = probs[torch.arange(target.numel()), target][mask]

only requires memory on the order of the size of target, since it gathers a single probability per position.
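
As a quick sanity check (again a toy-sized sketch; the mask below is just an illustrative padding mask, not the metric's exact code), both expressions select the same values:

import torch

N, vocab = 8, 12
probs = torch.rand(N, vocab)
target = torch.randint(vocab, (N,))
mask = target != 0               # e.g. ignore token index 0

old = probs[:, target].diagonal()[mask]                   # builds an (N, N) temporary
new = probs[torch.arange(target.numel()), target][mask]   # stays at size N
print(torch.equal(old, new))     # True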

Would you consider accepting a pull request with this optimization? Or was the current implementation chosen for another reason?

Environment

  • TorchMetrics v1.2.1 (installed with pip) and the master branch
  • Python 3.10.12
  • PyTorch 2.2.0
  • CUDA 12.1
nsmlzl added the bug / fix and help wanted labels Jan 30, 2024

Hi! Thanks for your contribution, great first issue!

Borda added the v1.2.x label Jan 31, 2024
nsmlzl added a commit to nsmlzl/mamborosDNA that referenced this issue Jan 31, 2024
required an optimization in torchmetrics;
see Lightning-AI/torchmetrics#2337
nsmlzl (Contributor, Author) commented Feb 2, 2024

Just created PR #2346 with the (small) change. Feel free to merge it whenever you like.
