[1/n] Migrate activation kernels to libtorch stable ABI#30908
mikaylagawarecki wants to merge 2 commits into vllm-project:main
Conversation
Code Review
This pull request migrates the activation kernels to use the PyTorch stable ABI. This is a good step towards improving the long-term maintainability of the CUDA kernels by depending on a stable API. The changes are well-structured, creating a new _C_stable extension for the migrated kernels and updating the build system accordingly.
My review focuses on the implementation details of this migration. I've identified an opportunity to improve the maintainability of the new CUDA code by refactoring duplicated logic in the kernel launch macros. This will make the code cleaner and less prone to errors in future modifications.
Thank you for putting it together! This is definitely what we need. A general question: does it make sense to migrate to the libtorch stable ABI only partially? Other minor comments on your changes:
What about a name like
IIUC this will lead to duplicate symbols in both library files. Is this necessary, i.e., who might still be using the old symbols?
We could do this after the release of PyTorch 2.10. I think that is about a month away. During this period we could run build and test on nightly PyTorch.
Hi @Harry-Chen, Thank you for your feedback!
I think you are right here. The rationale for starting this with an initial small PR is per youkaichao's feedback on the initial issue
Indeed, the final switch of building a libtorch ABI stable wheel would have to wait till all the relevant files are only using stable APIs, but the progress does not need to be all or nothing. Does it make sense to you if I continue this enablement for the files in the CUDA wheel with a stack that migrates files one by one?
👍 Will fix!
To clarify what I meant here: I don't think there are duplicate symbols, because I removed the registrations from the respective TORCH_LIBRARY in the
That sounds good! Is there a guide on how to enable this in CI?
Yes, this totally makes sense. So our target is to move all C++ compilation units to
I got it now. Thank you for explaining!
Maybe you can refer to https://docs.vllm.ai/en/latest/contributing/ci/update_pytorch_version/#test-pytorch-release-candidates-rcs. And CC @youkaichao and @khluu for more ideas here.
We had gone over the APIs in the CUDA build of vLLM before, and I believe we should have what is needed. Would you object to the partial migration if we find that there are a few kernels that must remain unstable for 2.10? I am aware that vLLM also has dependencies on FlashMLA, qutlass, and FlashAttention 2/3. Would migrating these be a prerequisite? Currently only FA3 is on torch's stable ABI.
Of course not; any progress would be good, even if it's not 100%, as long as we are going towards it.
The ideal situation is that everything gets migrated. Please note that the FlashMLA and FlashAttention we use in vLLM are forks (under vllm-project on GitHub), so this might require extra work (either cherry-picking upstream's changes, rebasing, or doing it on our own). You could also evaluate the amount of work required, and we can discuss later.
Documentation preview: https://vllm--30908.org.readthedocs.build/en/30908/
Hi @mikaylagawarecki, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
This pull request has merge conflicts that must be resolved before it can be merged.
Cherry-picked from temp4 996b0bb with the following additional changes:
- Merged M6 Blackwell 256-bit vectorization (PackedTraits, ld256/st256, cc_major branching) with stable ABI APIs (DeviceGuard, get_current_cuda_stream, mutable_data_ptr/const_data_ptr, VLLM_STABLE_DISPATCH_FLOATING_TYPES)
- Added a get_device_prop() cached device-properties utility to csrc/stable/torch_utils.h for stable-ABI-compatible device queries
- Added explicit cuda_bf16.h/cuda_fp16.h includes for packed type intrinsics (hip equivalents on ROCm)
- Replaced c10::BFloat16/c10::Half with torch::headeronly:: equivalents
- ROCm fix: cuda_compat.h -> ../cuda_compat.h in stable/activation_kernels.cu
- ROCm fix: import _C_stable_libtorch in vllm/platforms/rocm.py

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES})
else()
  target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES})
  # Link against PyTorch's bundled libtorch_hip.so (for DeviceGuard registration)
On a ROCm machine, I found that I needed these changes to make sure that vLLM was linking against these two .so files from torch (in particular the libamd.so HIP headers packaged by torch); otherwise the stable DeviceGuard would not work correctly.
It seemed that two separate hipContexts were created: one by the raw HIP calls vLLM made (which came from a HIP header elsewhere), and one by the libtorch shims that called raw HIP APIs using the HIP headers packaged by torch.
@mikaylagawarecki I'm not sure what's going on here. Do all extensions need this code? Or is vLLM doing something weird?
at::Tensor& y_s,  // (E, T, H//group_size) [OUT]
bool use_ue8m0);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
Since it might be confusing why this is deleted: I kept the other declarations because the CPU torch_bindings.cpp includes ops.h. All the other ops are also defined for CPU, but this op isn't defined for CPU, so its declaration is deleted.
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
# This ensures we only use C-shim APIs available in PyTorch 2.10+.
target_compile_definitions(_C_stable_libtorch PRIVATE
  TORCH_TARGET_VERSION=0x020A000000000000ULL)
nit: add a comment that "_C_stable_libtorch is ABI compatible with PyTorch >= TORCH_TARGET_VERSION, which is currently set to 2.10" (explicitly state what we get).
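For what it's worth, the constant above appears to pack the minimum version into a 64-bit integer, with the major version in the top byte and the minor in the next (0x02 and 0x0A for 2.10). A standalone sketch of that inferred encoding (the byte layout is deduced from the constant itself, not taken from PyTorch docs):

```python
def encode_torch_target_version(major: int, minor: int) -> int:
    # Pack major into bits 56-63 and minor into bits 48-55,
    # matching the layout of 0x020A000000000000 for PyTorch 2.10.
    return (major << 56) | (minor << 48)

print(hex(encode_torch_target_version(2, 10)))  # 0x20a000000000000
```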
@@ -0,0 +1,25 @@
/*
 * Stable ABI compatible dispatch utilities for vLLM.
nit: whenever you write "stable ABI" in a comment, you probably want to call it "libtorch stable ABI" for clarity.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
  // Activation ops
  // Activation function used in SwiGLU.
  m.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
Just to check... you are able to add tags to operators in the stable abi?
Do you mean adding at::Tag to the registration? That is currently not ABI stable (but I don't believe at::Tag is used within the repo at the moment).
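A side note on the schema string quoted above: the `!` after `Tensor` is what marks `result` as mutated in place. A hypothetical helper (my own sketch, not code from this PR or PyTorch) that extracts the mutable arguments from such a schema string:

```python
def mutable_args(schema: str) -> list[str]:
    """Return the names of arguments a torch schema string marks mutable (Tensor!)."""
    arglist = schema[schema.index("(") + 1 : schema.index(")")]
    names = []
    for arg in arglist.split(","):
        type_, name = arg.strip().rsplit(" ", 1)
        if type_.endswith("!"):  # e.g. "Tensor!" means written in place
            names.append(name)
    return names

print(mutable_args("silu_and_mul(Tensor! result, Tensor input) -> ()"))  # ['result']
```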
#include <mutex>
#include <string>
#include <vector>
there's no namespacing at all, is that intentional?
Purpose
This change requires PyTorch 2.10+ to build.

First step towards migrating the CUDA wheel to the libtorch stable ABI; see #26946. The benefit of migrating to the libtorch stable ABI is that the stable binary can be built once (with torch 2.10 or greater) and run on all versions of torch >= 2.10.

- activation_kernels.cu is migrated to use the stable ABI/API
- a new csrc/torch_bindings_stable.cpp registers the kernels in activation_kernels.cu via STABLE_TORCH_LIBRARY_FRAGMENT to the _C namespace
- a new _C_stable.so is built in addition to the existing _C.so, with sources csrc/activation_kernels.cu and csrc/torch_bindings_stable.cpp

I am looking for feedback on the following:

1. The migrated kernels are built into a new _C_stable target, is this ok?
   a. The rationale is that this extension is built with -DTORCH_TARGET_VERSION=... (ensuring only stable APIs are used)
   b. Kernels in this extension are still registered to the _C namespace for backward compat.
2. Regarding _C_stable, I'm wondering whether that is alright.
   a. It can be built once with any version >= 2.10 and will be able to run in all versions of torch >= 2.10
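The "build once with 2.10, run on any torch >= 2.10" property implies a simple version gate when deciding whether the stable extension can be loaded. A minimal, hypothetical sketch of such a check (the helper name is mine, not from this PR), assuming a version string shaped like torch.__version__:

```python
def meets_min_torch(version: str, minimum: tuple[int, int] = (2, 10)) -> bool:
    """True if a torch version string (e.g. '2.10.0+cu128') is >= (major, minor)."""
    release = version.split("+")[0]          # drop any local build suffix
    major, minor = (int(p) for p in release.split(".")[:2])
    return (major, minor) >= minimum

print(meets_min_torch("2.10.0+cu128"))  # True
print(meets_min_torch("2.9.1"))         # False
```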
Test Plan
pytest tests/kernels/core/test_activation.py

Test Result