
Torchvision decode_jpeg memory leak #4378

Open

igrekun opened this issue Sep 7, 2021 · 26 comments

Comments

@igrekun

igrekun commented Sep 7, 2021

🐛 Describe the bug

nvJPEG leaks memory and fails with OOM after ~1-2k images.

import torch
from torchvision.io import read_file, decode_jpeg

for i in range(1000): # increase to your liking until the GPU OOMs (:
    img_u8 = read_file('lena.jpg')
    img_nv = decode_jpeg(img_u8, device='cuda')

Probably related to the first response in #3848.

RuntimeError: nvjpegDecode failed: 5

is exactly the message you get after the OOM.

Versions

PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 11.1.0
Clang version: 12.0.1
CMake version: version 3.21.1
Libc version: glibc-2.33

Python version: 3.8.7 (default, Jan 19 2021, 18:48:37) [GCC 10.2.0] (64-bit runtime)
Python platform: Linux-5.13.8-arch1-1-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti
GPU 2: NVIDIA GeForce GTX 1080

Nvidia driver version: 470.57.02
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.2.2
/usr/lib/libcudnn_adv_infer.so.8.2.2
/usr/lib/libcudnn_adv_train.so.8.2.2
/usr/lib/libcudnn_cnn_infer.so.8.2.2
/usr/lib/libcudnn_cnn_train.so.8.2.2
/usr/lib/libcudnn_ops_infer.so.8.2.2
/usr/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] adabelief-pytorch==0.2.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.19.5
[pip3] pytorch-lightning==1.4.5
[pip3] torch==1.9.0+cu111
[pip3] torchaudio==0.9.0
[pip3] torchfile==0.1.0
[pip3] torchmetrics==0.4.1
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect

@vadimkantorov

vadimkantorov commented Sep 21, 2021

@fmassa Is torchvision.io.read_image affected in general? More broadly, is torchvision.io.read_image ready to be used in dataset classes instead of PIL / OpenCV for reading images from disk? Is it fast?
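For context, here is a minimal sketch of the kind of usage I mean (hypothetical dataset class and paths, decoding on the CPU):

from torch.utils.data import Dataset
from torchvision.io import read_image

class JpegDataset(Dataset):
    def __init__(self, paths):
        self.paths = paths  # list of .jpg file paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # read_image returns a uint8 CHW tensor, decoded on the CPU
        return read_image(self.paths[idx])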

@Scass0807

@NicolasHug @fmassa I'm also having this issue. I tried loading images in a loop using decode_jpeg directly to the GPU. It ran for a number of images, but GPU memory keeps growing until it runs out and the call fails. Memory is fine on the CPU. Is there a timeline for when this will be fixed? I'm hoping it will be fixed ASAP, as loading directly to the GPU is crucial for speeds fast enough to run in real time.

@rydenisbak

Same problem:
Ubuntu 20.04
Driver Version: 470.82.00
CUDA Version: 11.4
torch version '1.10.2+cu113'
torchvision version '0.11.3+cu113'

@NicolasHug
Member

Thanks all for the reports.

I took a look at this today. I can reproduce the leak: I see the memory usage reported by nvidia-smi going up constantly. But printing the allocated and reserved CUDA memory with torch.cuda.memory_stats() shows no leak; I always get

CUDA rsvrd = 12.583 MB
CUDA alloc = 12.583 MB

I thought the leak might come from the fact that we don't free the nvjpeg handle (we literally leak it for convenience):

namespace {
static nvjpegHandle_t nvjpeg_handle = nullptr;
}

but that's not the case: moving the handle back inside the function and properly destroying it with nvjpegDestroy() still leads to a leak.

I don't see the leak anymore when commenting out the nvjpegDecode() call, so it happens somewhere inside it. I don't understand why nvjpegDecode() has to allocate anything, though, because the output CUDA tensor's memory was already allocated prior to the call.

I don't know whether this is actually a bug in nvjpeg, or whether something else is going on. Either way, I don't understand it. nvjpeg allows passing custom device memory allocators, so perhaps there is something to do there.
I'll try to look into that, but meanwhile, if anyone has any idea of what's going on, I would appreciate the help.
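In the meantime, here's a minimal script reproducing the diagnostic above (assuming pynvml is installed and a local lena.jpg exists): the caching-allocator counters stay flat while the NVML-reported device usage keeps climbing.

import pynvml
import torch
from torchvision.io import read_file, decode_jpeg

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

img_u8 = read_file("lena.jpg")
for i in range(1000):
    decode_jpeg(img_u8, device="cuda")
    if i % 100 == 0:
        alloc = torch.cuda.memory_allocated() / 2**20
        rsvd = torch.cuda.memory_reserved() / 2**20
        used = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20
        print(f"alloc={alloc:.1f} MB  rsvrd={rsvd:.1f} MB  nvml used={used:.1f} MB")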

Cheers

@NicolasHug
Member

NicolasHug commented Feb 11, 2022

nvjpeg allows passing custom device memory allocators, so perhaps there is something to do there.
I'll try to look into that

Update: routing nvjpeg's device allocations through PyTorch's caching allocator still leaks 🥲

int dev_malloc(void **p, size_t s) {
    *p = c10::cuda::CUDACachingAllocator::raw_alloc(s);
    return 0;
}

int dev_free(void *p) {
    c10::cuda::CUDACachingAllocator::raw_delete(p);
    return 0;
}

...

nvjpegDevAllocator_t dev_allocator = {&dev_malloc, &dev_free};
nvjpegStatus_t status = nvjpegCreateEx(NVJPEG_BACKEND_DEFAULT, &dev_allocator,
                                       NULL, NVJPEG_FLAGS_DEFAULT, &nvjpeg_handle);

source for ref: https://github.com/NicolasHug/vision/blob/3f057677f5676bf12c42571f6bf72686dd9c61f2/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp#L28

@NicolasHug
Member

NicolasHug commented Feb 18, 2022

I had a chance to look at this more: this is an nvjpeg bug, and unfortunately I'm not sure we can do much about it.

It was fixed in CUDA 11.6, but I'm still observing the leak with 11.0 - 11.5.

A temporary fix for Linux users is to download the 11.6 nvjpeg.so, e.g. from here, and to tell the dynamic linker to use it instead of whatever you currently have installed (via LD_LIBRARY_PATH, LD_PRELOAD, or something else).
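For example, a hypothetical invocation (the library path is a placeholder for wherever you extracted the 11.6 package):

# Preload the CUDA 11.6 libnvjpeg so it wins over any bundled copy.
# /opt/nvjpeg-11.6/... is a placeholder path; adjust to your setup.
LD_PRELOAD=/opt/nvjpeg-11.6/lib64/libnvjpeg.so.11 python your_script.py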

@rydenisbak

Hello @NicolasHug, thanks for the answer! I reinstalled CUDA; now I have this version:
NVIDIA-SMI 510.47.03
Driver Version: 510.47.03
CUDA Version: 11.6
torch version '1.10.2+cu113'
torchvision version '0.11.3+cu113'

But the problem does not disappear. Should I rebuild torch with CUDA 11.6 from source?

@NicolasHug
Member

What does ldd <path_to_your_python_env>/site-packages/torchvision/image.so say regarding libnvjpeg? It's possible that it's still being linked against an 11.3 version, or that your .so file isn't as up to date as the one I linked to above (11.6.0.55-1).

@Scass0807

@NicolasHug Mine is showing /site-packages/torchvision/../torchvision.libs/libnvjpeg.90286a3c.so.11. How do I fix this to use the system CUDA?

@Scass0807

@rydenisbak Did you figure this out?

@NicolasHug
Member

@Scass0807 if the path is coming from /site-packages..., I assume it's the one from the official torchvision binaries, so it's probably from CUDA <= 11.3. The bug was fixed starting from 11.6, from what I could tell. Have you tried the workaround suggested above in #4378 (comment)?

@Scass0807

Scass0807 commented Mar 10, 2022

@NicolasHug
There are two of them.

libnvjpeg.so.11 => /usr/local/cuda-11.6/lib64/libnvjpeg.so.11 (0x00007fc51ba65000)
libpng16.7f72a3c5.so.16 => /home/steven/.local/lib/python3.8/site-packages/torchvision/../torchvision.libs/libpng16.7f72a3c5.so.16 (0x00007fc51b82e000)
libjpeg.ceea7512.so.62 => /home/steven/.local/lib/python3.8/site-packages/torchvision/../torchvision.libs/libjpeg.ceea7512.so.62 (0x00007fc51b5d9000)
libnvjpeg.90286a3c.so.11 => /home/steven/.local/lib/python3.8/site-packages/torchvision/../torchvision.libs/libnvjpeg.90286a3c.so.11 (0x00007fc51aee8000)

@Scass0807

@NicolasHug I added a symlink libnvjpeg.90286a3c.so.11 -> /usr/local/cuda-11.6/lib64/libnvjpeg.so.11. Now there is only one nvjpeg, but the memory leak persists. I wonder if this is because, even though I am using 11.6, my driver version is 495.23, which is technically for 11.5. I am using GCP Compute Engine, and unfortunately they do not yet support 511.
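In case it helps anyone trying the same workaround, the steps were roughly (paths from my machine; yours will differ):

cd ~/.local/lib/python3.8/site-packages/torchvision.libs
# keep a backup of the bundled library before replacing it
mv libnvjpeg.90286a3c.so.11 libnvjpeg.90286a3c.so.11.bak
ln -s /usr/local/cuda-11.6/lib64/libnvjpeg.so.11 libnvjpeg.90286a3c.so.11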

@tp-nan

tp-nan commented Mar 11, 2022

Hi @NicolasHug, would you mind telling us where you got this information? I could not find it in the CUDA 11.6 release notes, and I cannot reproduce this memory leak with CUDA 10.2 (docker pull nvcr.io/nvidia/cuda:10.2-cudnn8-devel-ubuntu18.04). It would be great if there were some more information.

@NicolasHug
Member

Would you mind telling us where you got this information?

I basically tried all versions I could find from https://pkgs.org/search/?q=libnvjpeg-devel

@Scass0807

@NicolasHug Should installing CUDA 11.6 and using the libnvjpeg that ships with it work? Do I have to install the RPM? I'm on Ubuntu. Based on the ldd results above, I'm not sure there's anything else I can do.

@tp-nan

tp-nan commented Mar 31, 2022

Would you mind telling us where you got this information?

I basically tried all versions I could find from https://pkgs.org/search/?q=libnvjpeg-devel

It seems that there is a small multithreading issue here:

std::once_flag nvjpeg_handle_creation_flag;

The nvjpeg_handle_creation_flag should be global, not local; see the sketch below.
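A minimal sketch of the intended pattern (error handling elided; this illustrates the fix, not the exact torchvision patch):

#include <mutex>
#include <nvjpeg.h>

namespace {
nvjpegHandle_t nvjpeg_handle = nullptr;
// Global: all callers share one flag, so the handle is created exactly once.
// A local flag would give each call its own "once" and defeat the purpose.
std::once_flag nvjpeg_handle_creation_flag;
}

void ensure_nvjpeg_handle() {
  std::call_once(nvjpeg_handle_creation_flag, []() {
    nvjpegCreateSimple(&nvjpeg_handle);  // error handling elided
  });
}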

@Kubci

Kubci commented Jun 3, 2022

Hi,

I am using:
pytorch 1.11.0+cu113
ubuntu 20.04 LTS
python 3.9

I did replace libnvjpeg.90286a3c.so.11 with the .so from CUDA 11.6.2. However, the memory keeps growing indefinitely.

@Kubci

Kubci commented Jun 3, 2022

It does not work with the CUDA 11.7 libnvjpeg either, and the same behavior is observed when using numpy.frombuffer. Now I have to decode JPEGs on the CPU like a peasant :'(

@henry-zwart

Also seeing this issue on CUDA 11.6 (running in a Docker container):

Traceback (most recent call last):
  File "/code/scripts/predict-noncapture-hauls.py", line 158, in <module>
    run()
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.9/dist-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/code/scripts/predict-noncapture-hauls.py", line 144, in run
    predictions, metadata = evaluate_event_yolov5(
  File "/code/scripts/predict-noncapture-hauls.py", line 75, in evaluate_event_yolov5
    for batch in dataloader:
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.9/dist-packages/torch/_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/code/scripts/predict-noncapture-hauls.py", line 39, in __getitem__
    im_nv = decode_jpeg(im_u8, device="cuda").float() / 255
  File "/usr/local/lib/python3.9/dist-packages/torchvision/io/image.py", line 155, in decode_jpeg
    output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
RuntimeError: nvjpegDecode failed: 5

@dschoerk

I just checked whether this was fixed in the PyTorch nightly with CUDA 11.6, but I'm still experiencing a memory leak.

python -m pip install torch torchvision --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu116

@yupbank

yupbank commented Oct 12, 2022

same ^

@Inkorak

Inkorak commented Oct 19, 2022

Yes, there are still leaks, even on CUDA 11.6.

@Amadeus-AI

Amadeus-AI commented Dec 10, 2022

Memory leaks on torchvision-0.14.0+cu117 (torchvision-0.14.0%2Bcu117-cp37-cp37m-win_amd64.whl).
When will this be fixed?

Easy to reproduce (imports added for completeness; jpeg_bytes is the raw contents of any JPEG file):

import torch
import torchvision

for i in range(10000):
    torchvision.io.decode_jpeg(torch.frombuffer(jpeg_bytes, dtype=torch.uint8), device='cuda')

The memory leak didn't happen when using pynvjpeg 0.0.13, which seems to be built with CUDA 10.2:

from nvjpeg import NvJpeg

nj = NvJpeg()
nj.decode(jpeg_bytes)

@sh524shin

Has anyone solved this problem? I also tried pynvjpeg; it is slower than torchvision.io.decode_jpeg, and eventually an error message pops up like this: what(): memory allocator error, aborted (core dumped).

@langren666

It seems that this problem has been solved. My environment is as follows:
System: Ubuntu 22.04
NVIDIA-SMI: 535.86.05
Driver Version: 535.86.05
CUDA Version: 12.2
torch version: 2.0.1+cu118
torchvision version: 0.15.2+cu118

Finally, after waiting for over a year. :)
