Pytorch 2.1 GPU access not seen / managed by nvshare #11
@t-arsicaud-catie Can you re-run with the debug environment variable set, so that nvshare prints its debug output? In other words, run the application with debug logging enabled.
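For reference, a minimal invocation could look like this (a sketch: I'm assuming the test script is named `test.py`, that `libnvshare.so` is installed under `/usr/local/lib`, and that `NVSHARE_DEBUG=1` is the debug switch from the README):

```
# Run the app with nvshare preloaded and debug logging enabled
LD_PRELOAD=/usr/local/lib/libnvshare.so NVSHARE_DEBUG=1 \
CUDA_VISIBLE_DEVICES=0 python3 test.py
```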
Hi, in fact I don't get any debug output in the PyTorch app terminal with torch 2.1, and nothing appears in the nvshare scheduler logs either.

Whereas when I run the code with torch 2.0.1, I do get nvshare debug output in the app terminal, and matching entries in the scheduler logs.

In both cases, the environment variables (LD_PRELOAD, CUDA_VISIBLE_DEVICES, the debug variable) are set identically.
This is weird. We need to verify whether the PyTorch 2.1.0 application is indeed making the CUDA calls that nvshare intercepts. Can you run the application under a debugger and check? You can do this with gdb, setting breakpoints on the hooked CUDA driver functions. Then, paste the logs here.
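A session along these lines should work (a sketch, assuming the script is named `test.py` and `libnvshare.so` lives in `/usr/local/lib`; the breakpoint symbols are the driver functions nvshare hooks):

```
$ gdb --args python3 test.py
(gdb) set breakpoint pending on
(gdb) set environment LD_PRELOAD /usr/local/lib/libnvshare.so
(gdb) break cuInit
(gdb) break cuGetProcAddress
(gdb) run
```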
I am not used to using gdb, but I suppose this is what you're asking for (running the script under gdb with breakpoints on the hooked functions):
So no hits on the breakpoints with torch 2.1.0. With previous versions of torch, the calls are caught as expected. For the tests, I just switch from one virtual environment to another, keeping the same environment variables.
Good job with gdb! However, there is a little problem: one of the breakpoints was set on the wrong symbol. Could you rerun the test with a breakpoint on the correct one?

If you redo the initial test with the corrected breakpoint, you'll notice the difference.
Thank you for your answer and sorry for the inconvenience. Here is the output of the corrected gdb run with torch 2.1.0:

With the same breakpoints and torch 2.0.1:

And yes, of course you are right: with a breakpoint on the correct symbol, the behavior is as you described.
Hmm, this is strange... Let's take a step back and verify that the dynamic linker/loader indeed links the application against libnvshare.so. Could you run the application with the glibc loader's LD_DEBUG facility enabled and capture its output?

In this case we want to examine which libraries get loaded and how the CUDA symbols are bound.
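For instance (a sketch; LD_DEBUG and LD_DEBUG_OUTPUT are standard glibc loader variables, and the script name is assumed):

```
# Write loader diagnostics to a file instead of stderr;
# glibc appends the PID, so the output lands in /tmp/ld.log.<pid>
LD_DEBUG=all LD_DEBUG_OUTPUT=/tmp/ld.log \
LD_PRELOAD=/usr/local/lib/libnvshare.so python3 test.py
```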
Hi, thank you for your answer. I'm quite confused, as the output of the LD_DEBUG run is huge. The beginning of the log file contains only generic loader messages.

Is there a way to filter out or extract the relevant information? I've tried a few grep filters on the nvshare and CUDA names, without much success.
Hmmm, I didn't predict it would be this big. The problem is with the verbosity of LD_DEBUG=all, which logs every loader event. To avoid having a single, huge log file, can you split the process in two steps, as shown below? Do this for both the PyTorch 2.1.0 and 2.0.1 applications.
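Something like this, run once in each virtual environment (a sketch; file names assumed):

```
# Pass 1: which shared libraries are searched for and loaded
LD_DEBUG=libs LD_DEBUG_OUTPUT=/tmp/ld-libs \
LD_PRELOAD=/usr/local/lib/libnvshare.so python3 test.py

# Pass 2: how individual symbols get bound to libraries
LD_DEBUG=bindings LD_DEBUG_OUTPUT=/tmp/ld-bindings \
LD_PRELOAD=/usr/local/lib/libnvshare.so python3 test.py

# Then filter for the interesting parts
grep -i nvshare /tmp/ld-libs.*
grep cuGetProcAddress /tmp/ld-bindings.*
```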
We're getting closer!
Hi. Referring to the libs log for the torch 2.1.0 app, libnvshare.so is indeed found and loaded by the dynamic loader, so the LD_PRELOAD itself seems to work.

For bindings, I've extracted the lines mentioning libnvshare.so and the CUDA entry points.
For comparison, I ran the same tests with the torch 2.0.1 app and collected the equivalent libs and bindings logs.
Thanks for taking the time to run these tests. I'd like to take a look at the full logs (both for libs and bindings), so if you could mail them to me, or upload them to a public place, I'll happily take a look.
Also, I'd like you to rerun the gdb test, this time with breakpoints on the _v2 variants of the hooked functions as well.
I noticed that PyTorch 2.1.0 (from PyPI -- the one you have installed) comes with CUDA 12.x, while PyTorch 2.0.1 comes with CUDA 11.x. CUDA 12.0 introduced a new function, cuGetProcAddress_v2(), which Runtime API applications call instead of cuGetProcAddress() in order to obtain Driver API symbols. nvshare currently hooks only cuGetProcAddress(), so a CUDA 12.x application would bypass it entirely. To verify this, could you uninstall PyTorch 2.1.0 and re-install it built against CUDA 11.8, following the official instructions [1]? Then, rerun the PyTorch 2.1.0 example; my prediction is that it will work. [1] https://pytorch.org/get-started/previous-versions/#linux-and-windows-1
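Concretely, something like the following should do it (a sketch based on the linked page; the exact package versions may differ):

```
pip uninstall -y torch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
```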
Yes, you are right! In a CUDA 12.2 environment, with torch installed from the CUDA 11.8 wheel index, the application is seen and managed by nvshare. In the same CUDA 12.2 environment, it is not when torch has been installed with a plain pip install, which pulls the CUDA 12.x build. I've collected the full logs which you requested for the cuda 12.2 / cuda 12.2 scenario; I'll send them to you by email. For the gdb part, the symbols which are called in this scenario are, as you expected, the _v2 variants, in particular cuGetProcAddress_v2().
Great job! In order to support CUDA >=12 applications, we must also hook cuGetProcAddress_v2(). I will prepare (and merge) a PR tackling this when I get some time. In the meantime, you can use the CUDA 11.x builds of PyTorch as a workaround.
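For the record, a minimal sketch of what such a hook could look like (the wrapper name nvshare_cuInit and the way the real function pointer is resolved are placeholders for this sketch, not nvshare's actual code):

```c
#include <string.h>
#include <cuda.h>

/* nvshare's wrapper for cuInit() (hypothetical name for this sketch) */
extern CUresult nvshare_cuInit(unsigned int flags);

/* Pointer to the real cuGetProcAddress_v2(), resolved at load time
 * (e.g. via dlsym() on the driver library); definition elided here. */
extern CUresult (*real_cuGetProcAddress_v2)(
    const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags,
    CUdriverProcAddressQueryResult *symbolStatus);

/* Interposed cuGetProcAddress_v2(): let the driver resolve the symbol,
 * then swap in nvshare's wrapper for the functions it interposes. */
CUresult cuGetProcAddress_v2(const char *symbol, void **pfn, int cudaVersion,
                             cuuint64_t flags,
                             CUdriverProcAddressQueryResult *symbolStatus)
{
    CUresult res = real_cuGetProcAddress_v2(symbol, pfn, cudaVersion,
                                            flags, symbolStatus);
    if (res != CUDA_SUCCESS)
        return res;

    /* Redirect cuInit() so CUDA 12.x runtimes hit the wrapper too */
    if (strcmp(symbol, "cuInit") == 0)
        *pfn = (void *)nvshare_cuInit;

    return res;
}
```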
Meanwhile, note that the definition of cuGetProcAddress_v2() is not identical to cuGetProcAddress(): it takes an extra argument.
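For reference, the two declarations side by side (as I understand them from the CUDA headers; the new last parameter reports the outcome of the symbol lookup):

```c
/* CUDA 11.x */
CUresult cuGetProcAddress(const char *symbol, void **pfn,
                          int cudaVersion, cuuint64_t flags);

/* CUDA 12.x: adds a status out-parameter as the last argument */
CUresult cuGetProcAddress_v2(const char *symbol, void **pfn,
                             int cudaVersion, cuuint64_t flags,
                             CUdriverProcAddressQueryResult *symbolStatus);
```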
Currently, we use cuInit() as our entry point. According to the CUDA documentation, it is the only function that applications must necessarily call before using a GPU. In the case of applications that use the CUDA Runtime API, the runtime internally calls cuInit() when it initializes.

Therefore, hooking cuInit() has been sufficient so far for CUDA 11.x applications. Regarding the differences between cuGetProcAddress() and cuGetProcAddress_v2(): do you want to point out something specific about the approach we should take regarding the last argument of cuGetProcAddress_v2()? Perhaps you can experiment a bit with a CUDA 12.x Runtime API application and see how it uses the function.
@grgalex It means that a CUDA 12.x runtime obtains its driver entry points, including cuInit(), through cuGetProcAddress_v2(). Therefore, unless we hook cuGetProcAddress_v2(), our cuInit() wrapper is never reached.
You are right. I had missed that! By the way, do you want to prepare and send a PR for this? I'm kinda busy at the moment, so I would really appreciate any help! The suggested change is (correct me if I'm wrong): add a hook for cuGetProcAddress_v2(), mirroring the existing cuGetProcAddress() hook.
OK, I will submit a PR for this :)
CUDA 12.0 introduced a new function, cuGetProcAddress_v2(), which Runtime API applications call instead of cuGetProcAddress() in order to obtain Driver API symbols. To maintain compatibility with CUDA >=12.0 applications, add a hook for cuGetProcAddress_v2(). Closes grgalex#11 Signed-off-by: Xinyuan Lyu <[email protected]>
CUDA 12.0 introduced a new function, cuGetProcAddress_v2(), which Runtime API applications call instead of cuGetProcAddress() in order to obtain Driver API symbols. To maintain compatibility with CUDA >=12.0 applications, add a hook for cuGetProcAddress_v2(). Closes #11 Signed-off-by: Xinyuan Lyu <[email protected]> Reviewed-by: George Alexopoulos <[email protected]>
We just merged support for CUDA 12. Feel free to deploy from the latest source and report back.
Hi, sorry, I was unavailable for a while and could not test until now. It's done, and I confirm that it works well, at least with the latest versions of PyTorch / CUDA. Thank you both for your work and the improvement of nvshare!
Hi,

I recently discovered that a simple PyTorch GPU test script, which is, at execution time, registered and managed by nvshare with torch==1.13.1 and torch==2.0.1, is not with torch==2.1. The code runs as expected, accessing the GPU defined in CUDA_VISIBLE_DEVICES, but directly, bypassing the controls made by nvshare.

My test environment is the following:

- nvshare compiled and installed following the recommendations in the README
- CUDA_VISIBLE_DEVICES and LD_PRELOAD correctly set

Any idea why, and is there a way to prevent this (in nvshare or the PyTorch code) when CUDA_VISIBLE_DEVICES and LD_PRELOAD are correctly set?