[build] fix compute capability arch flags, add PTX, handle PTX #591
Conversation
op_builder/builder.py (outdated)
```diff
-        args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
+        args.append(f'-gencode=arch=compute_{cc},code=sm_{cc}')
+        if cc.endswith('+PTX'):
+            args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
```
Shouldn't we use the compute-capability (cc) number alone, rather than concatenated with '+PTX'?
Edit: I misunderstood what you said - you were pointing at a bug; it's fixed now. Thank you!
But I will leave what I originally wrote below if it helps to explain things in general.
There are 2 unrelated things.
1. The code you quoted above does the right thing (edit: after I fixed the bug below) if a user passes a `TORCH_CUDA_ARCH_LIST` that contains `+PTX`, as my previous PR missed that part (I had no clue how `+PTX` was converted to compile flags at that time, but now I do).
2. The question is whether to add `+PTX` for jit_mode or not here:
   https://github.com/microsoft/DeepSpeed/blob/03111cef09b33a25792246785ab0ded68be1733c/op_builder/builder.py#L246
I added it to sync with pytorch's CUDAExtension implementation, which is how it'll do it once pytorch-1.8 is released. But you can choose to not include it - I have no idea how and when jit_mode is used, so I can't tell whether it's needed or not.

So to conclude: the code you quoted above does the right thing if any of the ccs contains `+PTX`, which may come from a user-defined `TORCH_CUDA_ARCH_LIST` - so (1) is a must. But w.r.t. (2) you can choose not to include it. If you'd rather not, please let me know and I will remove https://github.com/microsoft/DeepSpeed/blob/03111cef09b33a25792246785ab0ded68be1733c/op_builder/builder.py#L246

Let me know.
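To illustrate the `+PTX` handling discussed above, here is a minimal, self-contained sketch (an illustration, not DeepSpeed's actual code; `gencode_flags` is a made-up name) of how entries from `TORCH_CUDA_ARCH_LIST` could map to nvcc `-gencode` flags:

```python
def gencode_flags(arch_list):
    """Sketch: turn entries like '6.1' or '8.0+PTX' into nvcc -gencode flags."""
    args = []
    for arch in arch_list:
        ptx = arch.endswith('+PTX')
        # strip the '+PTX' suffix and the dot: '8.0+PTX' -> '80'
        cc = arch.replace('+PTX', '').replace('.', '')
        # code=sm_XX emits real machine code (SASS) for that exact arch
        args.append(f'-gencode=arch=compute_{cc},code=sm_{cc}')
        if ptx:
            # code=compute_XX additionally embeds PTX, which the driver can
            # JIT-compile at runtime for newer GPUs
            args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
    return args

print(gencode_flags(['6.1', '8.0+PTX']))
# ['-gencode=arch=compute_61,code=sm_61',
#  '-gencode=arch=compute_80,code=sm_80',
#  '-gencode=arch=compute_80,code=compute_80']
```

Note that the bug you flagged is exactly the stripping step: without it, `compute_{cc}` would expand to something like `compute_8.0+PTX`.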
Oh, I see what you meant: there is a bug in my code - sorry, I forgot to check the result of the test runs. Thank you for flagging that, @RezaYazdaniAminabadi.
Thanks @stas00 for the fix. 👍
Now that pytorch/pytorch#48891 has been merged I have a better understanding of how the arch code is generated.
It looks like your original code wasn't producing it correctly, as `-gencode=arch=compute_{cc},code=compute_{cc}` generates PTX code and not the normal compute code. Looking at https://github.com/pytorch/pytorch/blob/4434c07a2c0ba4debc6330063546f600aee8deb3/torch/utils/cpp_extension.py#L1556-L1563, it needed to be `-gencode=arch=compute_{cc},code=sm_{cc}` in your code in the first place. Notice that the last bit is `code=sm_` and not `code=compute_`.

So as a follow-up to #578, this PR:

- does `s/code=compute_/code=sm_/` for the normal archs
- adds `+PTX` for the highest arch in jit-mode (a sketch of this part follows below)
- handles `+PTX` in user input (especially in the case of `TORCH_CUDA_ARCH_LIST`, which was missing from my previous PR)

Do note, I'm new to all this, so please kindly verify that I'm doing the right thing. I did compile and test the resulting binary, but only on my code.
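As for the jit-mode bullet above, a hedged sketch of how the default arch list could be derived from the visible GPUs, with `+PTX` appended to the highest arch the way pytorch's `cpp_extension` will do in 1.8 (`default_arch_list` is an illustrative name, not DeepSpeed's actual API):

```python
import torch

def default_arch_list():
    """Sketch: collect compute capabilities of visible GPUs and tag the
    highest one with '+PTX'."""
    archs = set()
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        archs.add(f'{major}.{minor}')
    # sort numerically, e.g. '6.1' < '8.0'
    archs = sorted(archs, key=lambda a: tuple(map(int, a.split('.'))))
    if archs:
        # the driver can JIT the embedded PTX for GPUs newer than the
        # newest one present at build time
        archs[-1] += '+PTX'
    return archs
```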
The weird syntax for PTX I copied from pytorch (link above); https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#just-in-time-compilation explains that this is how PTX is encoded in nvcc flags.

So `-gencode=arch=compute_{cc},code=compute_{cc}` tells nvcc to enable PTX for that arch. Very confusing.
Thanks.