This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Performance Regression] GPU memory increase for training and inference models #18280

Open
karan6181 opened this issue May 11, 2020 · 11 comments

@karan6181
Contributor

Description

  • An MXNet nightly benchmark runs CV and NLP models on the MXNet nightly pip wheel and reports their metrics. It showed a performance regression in GPU memory usage.
  • Bisecting the PRs identified #17767 as the change that increases GPU memory usage by 120-130 MB. I ran SSD training with and without that PR's commit and saw an increase of around 120 MB. I haven't run all of the models myself, but the reproduction script is minimal and applies to all of them.
  • The affected models are listed below. The GPU memory figures come from our internal benchmarking system:

| Model | API | GPU memory before | GPU memory after |
|---|---|---|---|
| VGG16_training | - | 9.63k | 9.78k |
| SSD_training | - | 4.75k | 4.91k |
| LSTM_inference | Gluon | 1.55k | 1.71k |
| LSTM_inference | Module | 1.47k | 1.64k |
| Caffenet_inference | Gluon | 2.17k | 2.31k |
| Caffenet_inference | Module | 2.0k | 2.16k |
| YoloV3_GPU | Gluon | 1.87k | 2.03k |
| YoloV3_GPU | Module | 1.84k | 2.0k |
| Resnet50_v2_FP16_inference_GPU | Gluon | 1.58k | 1.74k |
| Resnet50_v2_FP16_inference_GPU | Module | 1.55k | 1.71k |
| Resnet50_v2_inference_GPU | Gluon | 1.55k | 1.71k |
| Resnet50_v2_inference_GPU | Module | 1.48k | 1.64k |
| Resnet152_v2_inference_GPU | Gluon | 1.92k | 2.08k |
| Resnet152_v2_inference_GPU | Module | 1.86k | 2.02k |
| Inception_Inference_GPU | Gluon | 1.43k | 1.59k |
| Inception_Inference_GPU | Module | 1.40k | 1.56k |
| SSD_inference_GPU | Gluon | 1.65k | 1.81k |
| SSD_inference_GPU | Module | 1.54k | 1.70k |
| A3C_inference_GPU | Gluon | 1.32k | 1.48k |
| A3C_inference_GPU | Module | 1.31k | 1.47k |
| word_language_model_hybrid_p3.16_training | - | 2.25k | 2.43k |
| Mobile_pose_training | - | 2.38k | 2.54k |
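
As a quick sanity check on the figures above, the sketch below turns a few rows into per-model increases. It assumes that "k" means thousands of MiB (the unit nvidia-smi reports); the dictionary and the `deltas_mib` helper are illustrative only, not part of any benchmark tooling. Because the figures are rounded to 0.01k (10 MiB), the deltas are coarse: across the table they range from roughly 140 to 180 MiB, somewhat above the ~128 MiB measured by the source-build reproduction, which used a different measurement setup.

```python
# (before, after) GPU memory in thousands of MiB, copied from a subset of the table.
# Assumption: "k" = 1000 MiB.
readings = {
    "VGG16_training": (9.63, 9.78),
    "SSD_training": (4.75, 4.91),
    "LSTM_inference (Gluon)": (1.55, 1.71),
    "LSTM_inference (Module)": (1.47, 1.64),
    "Caffenet_inference (Gluon)": (2.17, 2.31),
    "word_language_model_hybrid_p3.16_training": (2.25, 2.43),
}

def deltas_mib(rows):
    """Return each model's increase in MiB, rounded to absorb float noise."""
    return {name: round((after - before) * 1000) for name, (before, after) in rows.items()}

increase = deltas_mib(readings)
for name, mib in increase.items():
    print(f"{name}: +{mib} MiB")
```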

To Reproduce

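  • Run the following code and monitor GPU usage with the nvidia-smi command.
```python
import mxnet as mx
a = mx.nd.zeros((1,), ctx=mx.gpu())  # initializes the CUDA context on gpu(0)
a.wait_to_read()  # block until the asynchronous allocation has completed
input("Check nvidia-smi now, then press Enter to exit.")  # keep the process alive
```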

Output:

MXNet built from source:

Instance: p3.16xlarge

Without PR #17767 (commit f882de0c7ecd6ff1f0fdba492865afc6d7e29271):
GPU usage: gpu(0): 1279MiB / 16160MiB

With PR #17767 (commit 5542d03695b4a2589afb88acf128d4ba8ac94d0d):
GPU usage: gpu(0): 1407MiB / 16160MiB
  • Tagging the PR author: @ptrendx
  • Thanks to @ptrendx for sharing the simple reproduction script.
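
The two nvidia-smi readings differ by 1407 - 1279 = 128 MiB, the figure discussed later in the thread. Here is a minimal sketch for extracting that number from readings in the `gpu(N): USEDMiB / TOTALMiB` format quoted above. The regular expression and the `parse_reading` helper are assumptions written for this format only, not for nvidia-smi's default table output.

```python
import re

# Matches readings such as "gpu(0): 1279MiB / 16160MiB" (the format quoted above).
READING = re.compile(r"gpu\((\d+)\):\s*(\d+)MiB\s*/\s*(\d+)MiB")

def parse_reading(line):
    """Return (device_id, used_mib, total_mib) parsed from one reading line."""
    match = READING.search(line)
    if match is None:
        raise ValueError(f"unrecognized reading: {line!r}")
    device, used, total = (int(g) for g in match.groups())
    return device, used, total

_, before, total = parse_reading("GPU usage: gpu(0): 1279MiB / 16160MiB")
_, after, _ = parse_reading("GPU usage: gpu(0): 1407MiB / 16160MiB")
delta = after - before
print(f"+{delta} MiB ({100 * delta / total:.2f}% of gpu(0))")  # +128 MiB (0.79% of gpu(0))
```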
karan6181 added the Bug label May 11, 2020
@karan6181
Contributor Author

@mxnet-label-bot update [Performance]

@ptrendx
Member

ptrendx commented May 11, 2020

Moving the details from the offline discussion.

The problem here is that when a library containing GPU kernels is loaded, those kernels are stored in GPU memory. So as the number of kernels in the MXNet library grows, two things happen:

To fully resolve this, I do not believe it is feasible to keep relying on compile-time template instantiation to generate those simple kernels; instead, we should move towards runtime compilation (RTC). This has a series of advantages:

  • only the kernels actually used by the user are compiled and loaded, so there is more GPU memory left for the actual DL training
  • compilation time of the library would be greatly reduced
  • size of the resulting library would be greatly reduced
  • it potentially provides more building blocks for fusion and makes those efforts easier
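
The first advantage, compiling only the kernels a user actually needs and then caching them, can be sketched in a few lines of Python. This is not MXNet code: `KERNEL_TEMPLATE`, `get_kernel`, and `fake_compile` are hypothetical names, and `fake_compile` merely stands in for a real NVRTC compile (in CUDA that would be `nvrtcCreateProgram`, `nvrtcCompileProgram`, and `nvrtcGetPTX`, followed by loading the PTX with a driver API call such as `cuModuleLoadData`). `functools.lru_cache` plays the role of the kernel cache, in the spirit of the caching proposed in #16783.

```python
from functools import lru_cache

# Hypothetical CUDA source template for an elementwise binary kernel. Doubled
# braces are literal braces for str.format; {name}, {dtype}, {expr} are filled in.
KERNEL_TEMPLATE = """
extern "C" __global__ void {name}(const {dtype}* a, const {dtype}* b, {dtype}* out, int n) {{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = {expr};
}}
"""

EXPRESSIONS = {"add": "a[i] + b[i]", "mul": "a[i] * b[i]"}

compiled = []  # log of every (op, dtype) that actually went through compilation

def fake_compile(source):
    """Stand-in for NVRTC: a real version would return PTX for `source`."""
    return source.encode("utf-8")

@lru_cache(maxsize=None)
def get_kernel(op, dtype):
    """Generate and compile a kernel the first time (op, dtype) is requested."""
    compiled.append((op, dtype))
    source = KERNEL_TEMPLATE.format(name=f"{op}_{dtype}", dtype=dtype, expr=EXPRESSIONS[op])
    return fake_compile(source)

# A workload that only launches float addition pays for a single compilation,
# however many times the kernel is requested.
for _ in range(1000):
    get_kernel("add", "float")
print(compiled)  # [('add', 'float')]
```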

This is not a silver bullet, though, as it also has downsides:

  • the cost of compiling those kernels is pushed to the end user on every launch of MXNet (although caching, as in "Speed fused_op compilation by caching ptx and jit-compiled device functions" #16783, means that each kernel needs to be compiled only once). That said, each such compilation takes a few ms, so the additional cost should be well below a second for the majority of users, and it could be partially offset by the reduced cost of loading the library onto the GPU
  • that cost would be paid in full during CI runs since, unlike actual users, CI exercises all of the library's functionality and would therefore need to compile everything
  • potential duplication of the math-function code, since using the same source both as compiled code and as text embedded for runtime compilation is not really possible without external tools (see e.g. https://stackoverflow.com/questions/410980/include-a-text-file-in-a-c-program-as-a-char)
  • this makes libnvrtc.so an essential component of CUDA builds, which means it and libcuda.so need to be loaded dynamically (there is already a request for that: "dynamic load libnvrtc.so" #17858)
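
To illustrate what loading libnvrtc.so dynamically means in practice, here is the pattern expressed with Python's `ctypes`, which uses `dlopen` under the hood on Linux. MXNet itself would do this in C++ with `dlopen`/`dlsym`, and libcuda.so would be handled the same way; this sketch only shows the shape of the idea: if the library is missing, RTC support is reported as unavailable instead of the process failing at load time. `nvrtcVersion(int*, int*)` is a real NVRTC entry point; the helper names `load_nvrtc` and `nvrtc_version` are hypothetical.

```python
import ctypes
import ctypes.util

def load_nvrtc():
    """dlopen NVRTC at runtime; return None if it is not installed."""
    path = ctypes.util.find_library("nvrtc")  # searches the system library paths
    if path is None:
        return None
    try:
        return ctypes.CDLL(path)
    except OSError:
        return None

def nvrtc_version():
    """Return NVRTC's (major, minor) version, or None when RTC is unavailable."""
    lib = load_nvrtc()
    if lib is None:
        return None  # RTC features would be disabled rather than failing at import
    major, minor = ctypes.c_int(), ctypes.c_int()
    # C signature: nvrtcResult nvrtcVersion(int* major, int* minor); 0 is NVRTC_SUCCESS.
    if lib.nvrtcVersion(ctypes.byref(major), ctypes.byref(minor)) != 0:
        return None
    return major.value, minor.value

print(nvrtc_version())
```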

I believe the benefits outweigh the downsides, especially given the push for DeepNumpy and the need to support many type combinations, which has already forced us to disable some of them in Windows builds because of these issues (see e.g. https://github.com/apache/incubator-mxnet/blob/f00b9ab5b4410a91a8f6581da696a92f85fbccf6/src/operator/numpy/np_elemwise_broadcast_op.cu#L32-L41)
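
A back-of-the-envelope sketch of why type combinations hurt compile-time instantiation. The dtype list and the operator count of 20 are made-up illustrative numbers, not MXNet's actual figures: instantiating every (lhs, rhs) dtype pair for each binary operator multiplies the number of kernels baked into the library (here 20 x 8^2 = 1280), while runtime compilation only pays for the distinct combinations a program actually launches.

```python
from itertools import product

# Illustrative only: neither the dtype list nor the operator count is MXNet's real figure.
DTYPES = ["float16", "float32", "float64", "int8", "uint8", "int32", "int64", "bool"]
NUM_BINARY_OPS = 20

def ahead_of_time_kernels(num_ops, dtypes):
    """Kernels compiled into the library if every (lhs, rhs) dtype pair is instantiated."""
    return num_ops * sum(1 for _ in product(dtypes, repeat=2))

def runtime_kernels(launches):
    """Kernels compiled at runtime: one per distinct (op, lhs, rhs) actually launched."""
    return len(set(launches))

aot = ahead_of_time_kernels(NUM_BINARY_OPS, DTYPES)
rtc = runtime_kernels([
    ("add", "float32", "float32"),
    ("mul", "float32", "float32"),
    ("add", "float32", "float32"),  # a repeated launch reuses the cached kernel
])
print(aot, rtc)  # 1280 2
```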

@apeforest @leezu @szha @eric-haibin-lin @sxjscience What are your thoughts about this?

@sxjscience
Member

I agree with using JIT compilation to resolve the binary size issue. @ptrendx @karan6181 According to the reproducible example, does this mean there is a 128 MB GPU memory increase just for storing the GPU code?

@ptrendx
Member

ptrendx commented May 12, 2020

Yes, that is right. In total we lose 1407 MB of GPU memory just by loading the library. It is not only kernel code being loaded (I believe ~400 MB of that is the CUDA context), but the majority is kernel code.

@leezu
Contributor

leezu commented May 12, 2020

@ptrendx, would your proposal be in scope for 2.0? One further advantage is that by compiling at runtime, we may be able to keep libmxnet.so from becoming subject to the CUDA EULA, as long as nvrtc.h can be licensed under a compatible license.

@ptrendx
Member

ptrendx commented May 12, 2020

@leezu That would require doing everything CUDA-related via RTC, whereas my current proposal is limited to a portion of the kernels.

I would like to start working on this RTC approach ~now, so yes, it is definitely in scope for 2.0 :-).

@ptrendx
Member

ptrendx commented May 12, 2020

@leezu @apeforest If anybody wants to help with the build side of the dynamic loading of libcuda.so and libnvrtc.so (especially CMake, since I have pretty much zero experience with it), it would be greatly appreciated.

@leezu
Contributor

leezu commented May 12, 2020

On the CMake side, what else do you expect to be required besides removing the declaration that libmxnet depends on libcuda and libnvrtc? It's not yet clear to me why further changes would be needed. Either way, I'm happy to help on the build side.

@ptrendx
Member

ptrendx commented May 12, 2020

For the dynamic loading of libnvrtc, I think that would be it. I was also thinking about something to prevent code duplication (along the lines of some answers to this SO question: https://stackoverflow.com/questions/410980/include-a-text-file-in-a-c-program-as-a-char), which would require generating some files during the build.

@leezu
Contributor

leezu commented May 13, 2020

I see. To run xxd, we can use https://cmake.org/cmake/help/latest/command/add_custom_command.html and add its OUTPUT files to the set of source files that CMake compiles for the libmxnet target.

@ptrendx
Member

ptrendx commented May 14, 2020

I believe xxd is Linux-only, so the better approach is to generate the file by wrapping the code in R"delimiter( and )delimiter", which should be much easier to do cross-platform.
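
A minimal sketch of that cross-platform approach: a small Python generator (hypothetical, not an existing MXNet script) that wraps a source file in a C++ raw string literal, so the same text can be compiled normally and also handed to NVRTC at runtime. A build step such as CMake's `add_custom_command` could run it with the input path, the output header path, and a variable name. The delimiter `__RTC__` is an arbitrary choice; C++ only requires that `)__RTC__"` not appear inside the wrapped text. Very large files may hit compiler limits on string-literal length (notably in MSVC).

```python
import sys

def wrap_as_raw_string(source, var_name, delimiter="__RTC__"):
    """Return C++ header text that embeds `source` verbatim as a raw string literal."""
    closing = f"){delimiter}\""
    if closing in source:
        # The raw string would end early; a different delimiter is needed.
        raise ValueError(f"source contains {closing!r}; choose another delimiter")
    if not source.endswith("\n"):
        source += "\n"
    return (
        "#pragma once\n"
        f"static const char {var_name}[] = R\"{delimiter}(\n"
        f"{source}"
        f"){delimiter}\";\n"
    )

def main(argv):
    # Usage: generate_rtc_header.py <input.cuh> <output.h> <variable_name>
    src_path, out_path, var_name = argv[1:4]
    with open(src_path, encoding="utf-8") as f:
        text = f.read()
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(wrap_as_raw_string(text, var_name))

if __name__ == "__main__" and len(sys.argv) == 4:
    main(sys.argv)
```

For example, wrapping `__device__ float sq(float x) { return x * x; }` as `kSquareSource` produces a header whose single constant holds that exact text between `R"__RTC__(` and `)__RTC__"`.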


5 participants