
[1/n] Migrate activation kernels to libtorch stable ABI #30908

Open
mikaylagawarecki wants to merge 2 commits into vllm-project:main from mikaylagawarecki:torch_stable_abi

Conversation


@mikaylagawarecki mikaylagawarecki commented Dec 17, 2025

Purpose

This change requires torch 2.10+ to build.

First step towards migrating the CUDA wheel to the libtorch stable ABI; see #26946. The benefit of migrating to the libtorch stable ABI is that the stable binary can be built once (with torch 2.10 or greater) and then run against any version of torch >= 2.10.

  1. activation_kernels.cu is migrated to use the stable ABI/API
  2. Creates csrc/torch_bindings_stable.cpp that registers the kernels in activation_kernels.cu via STABLE_TORCH_LIBRARY_FRAGMENT to the _C namespace
  3. Creates a new _C_stable.so, in addition to the existing _C.so, with sources:
     - csrc/activation_kernels.cu
     - csrc/torch_bindings_stable.cpp
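For context on what the migrated kernel computes: silu_and_mul (the activation used in SwiGLU) splits the input in half along the last dimension and computes SiLU(x) * y. A minimal pure-Python sketch of those semantics (names and layout here are illustrative, not the kernel's actual code):

```python
import math

def silu(x: float) -> float:
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def silu_and_mul(row: list[float]) -> list[float]:
    # Input is [x; y] concatenated along the last dim;
    # output has half the width of the input.
    d = len(row) // 2
    return [silu(row[i]) * row[i + d] for i in range(d)]
```

The CUDA kernel fuses the split, activation, and multiply into a single pass over the tensor; this sketch only pins down the expected numerics.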

I am looking for feedback on the following

  1. The PR separates the stable kernels into a new _C_stable target, is this ok?
    a. The rationale is that this extension is built with -DTORCH_TARGET_VERSION=... (ensuring only stable APIs are used)
    b. Kernels in this extension are still registered to the _C namespace for backward compat.
  2. The PR requires torch >= 2.10 in order to build _C_stable; I'm wondering whether that is alright
    a. It can be built once with any version >= 2.10 and will run in all versions of torch >= 2.10
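A rough sketch of what such a separate target could look like in CMake, assuming vLLM's define_gpu_extension_target helper (the exact names, arguments, and the hypothetical _C_stable target here are illustrative, not the PR's actual build code):

```cmake
# Hypothetical sketch: a separate extension target restricted to the
# libtorch stable ABI. Target/variable names are illustrative.
define_gpu_extension_target(
  _C_stable
  DESTINATION vllm
  LANGUAGE CUDA
  SOURCES csrc/activation_kernels.cu csrc/torch_bindings_stable.cpp
  USE_SABI 3)

# Restrict this target to C-shim APIs available in torch >= 2.10,
# so the resulting .so is ABI compatible with any torch >= 2.10.
target_compile_definitions(_C_stable PRIVATE
  TORCH_TARGET_VERSION=0x020A000000000000ULL)
```

Keeping the define on its own target is what makes the guarantee checkable at build time: any use of a non-stable API in these sources fails to compile rather than silently binding to a specific libtorch.so.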

Test Plan

pytest tests/kernels/core/test_activation.py

Test Result

[Screenshot 2025-12-17 at 5:47:19 PM]


@gemini-code-assist (bot) left a comment


Code Review

This pull request migrates the activation kernels to use the PyTorch stable ABI. This is a good step towards improving the long-term maintainability of the CUDA kernels by depending on a stable API. The changes are well-structured, creating a new _C_stable extension for the migrated kernels and updating the build system accordingly.

My review focuses on the implementation details of this migration. I've identified an opportunity to improve the maintainability of the new CUDA code by refactoring duplicated logic in the kernel launch macros. This will make the code cleaner and less prone to errors in future modifications.


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@mergify mergify bot added the cpu Related to CPU backends label Dec 17, 2025

@Harry-Chen (Member) commented:

Thank you for putting it together! This is definitely what we need.

A general question: does it make sense to migrate to the libtorch stable ABI only partially? In other words,
the remaining uses of the non-stable libtorch ABI will still exist, and they will still bind us to a specific version of libtorch.so.
Of course, this work is not easy and can be done part by part, but unless we get rid of all of them, the migration will not take actual effect.

Other minor comments for your changes:

The PR separates the stable kernels into a new _C_stable target, is this ok?

What about a name like _C_stable_libtorch, just for more clarity?

b. Kernels in this extension are still registered to the _C namespace for backward compat.

IIUC this will lead to duplicate symbols in both library files. Is this necessary, i.e., who might still be using the old symbols?

The PR requires torch version >=2.10 in order to build _C_stable, I'm wondering whether that is alright

We could do this after the release of PyTorch 2.10; I think that is about a month away. During this period we could build and test against nightly PyTorch.


mikaylagawarecki commented Dec 19, 2025

Hi @Harry-Chen,

Thank you for your feedback!

But unless we could get rid of all of them, this migration will not take actual effect.

I think you are right here. The rationale for starting this with an initial small PR is per youkaichao's feedback on the initial issue

Is it possible to use stable APIs gradually? Like gradually enable it in some files. If we have to do it once or nothing, I'm afraid that would be a very large PR (similar to the FA3 one) and never get merged (vLLM's kernel changes are quite frequent).

Indeed, the final switch of building a libtorch ABI stable wheel would have to wait till all the relevant files are only using stable APIs, but the progress does not need to be all or nothing. Does it make sense to you if I continue this enablement for the files in the CUDA wheel with a stack that migrates files one by one?

What about a name like _C_stable_libtorch, just for more clarity?

👍 Will fix!

IIUC this will lead to duplicate symbols in both library files. Is this necessary, i.e., who might still be using the old symbols?

To clarify what I meant here: the _C_stable target uses a STABLE_TORCH_LIBRARY_FRAGMENT with the namespace _C. This ensures that the function is still callable from Python (e.g. as torch.ops._C.silu_and_mul).

I don't think there are duplicate symbols, because I removed the registrations from the respective TORCH_LIBRARY in the _C target.

We could do this after the release of PyTorch 2.10. I think that is about ~1 month later. And during this period we could run build and test on nightly pytorch

That sounds good! Is there a guide on how to enable this in CI?

@Harry-Chen (Member) commented:

@mikaylagawarecki

Indeed, the final switch of building a libtorch ABI stable wheel would have to wait till all the relevant files are only using stable APIs, but the progress does not need to be all or nothing. Does it make sense to you if I continue this enablement with a stack that migrates files one by one?

Yes, this totally makes sense. So our target is to move all C++ compilation units to VLLM_STABLE_EXT_SRC. Do you have an estimate for this, e.g. can we finish with PyTorch 2.10, or will we be blocked by some APIs that do not exist yet?

To clarify what I meant here, I meant that the stable library in C_stable registers a TORCH_LIBRARY_FRAGMENT with the namespace _C. This ensures that the function is still callable from python as torch._C.silu_and_mul.

I don't think there are duplicate symbols because I removed the registrations from the respective TORCH_LIBRARY in the _C target

I got it now. Thank you for explaining!

We could do this after the release of PyTorch 2.10. I think that is about ~1 month later. And during this period we could run build and test on nightly pytorch

That sounds good! Is there a guide on how to enable this in CI?

Maybe you can refer to https://docs.vllm.ai/en/latest/contributing/ci/update_pytorch_version/#test-pytorch-release-candidates-rcs. And CC @youkaichao and @khluu for more ideas here.


mikaylagawarecki commented Dec 19, 2025

Yes this totally makes sense. So our target is to move all c++ compilation units to VLLM_STABLE_EXT_SRC. Do you have any evaluation on this, e.g. can we finish this with PyTorch 2.10, or will we be blocked by some APIs that do not exist yet?

We had gone over the APIs in the CUDA build of vLLM before, and I believe we should have what is needed for vllm/_C.abi3.so, vllm/_moe_C.abi3.so, and vllm/cumem_allocator.abi3.so in torch 2.10 (note this doesn't include the ROCm wheel, CPU wheel, etc.). That said, there might be unknown unknowns (e.g. kernels that changed or APIs that were missed), so I cannot say this with 100% confidence until I migrate the kernels (which I intend to do asap :))

Would you object to the partial migration if we find that there are a few kernels that must remain unstable for 2.10?

I am aware that vLLM also has dependencies on FlashMLA, qutlass, and FlashAttention 2/3. Would migrating these be a prerequisite? Currently only FA3 is on torch's stable ABI.

@Harry-Chen (Member) commented:

We had gone over the APIs in the CUDA build of vllm before and I believe we should have what is needed for vllm/_C.abi3.so, vllm/_moe_C.abi3.so and vllm/cumem_allocator.abi3.so in torch 2.10 (note this doesn't include rocm wheel or cpu wheel etc.). That said there might be unknown-unknowns (e.g. kernels that changed or APIs that were missed). So I cannot say this with 100% confidence until I migrate the kernels (which I intend to do asap :))

Would you be in objection of the partial migration if we find that there are a few kernels that must be unstable for 2.10?

Of course not; any progress is good, even if not 100%, as long as we are moving towards it.

I am aware that vllm also has dependencies on FlashMLA, qutlass and flash attention2/3. Would migrating these be a prerequisite? Currently only fa3 is on torch's stable abi

The ideal situation is that everything gets migrated. Please note that the FlashMLA and FlashAttention we use in vLLM are forks (under vllm-project on GitHub), so this might require extra work (either cherry-picking upstream's changes, rebasing, or doing it on our own). You could also evaluate the amount of work required, and we can discuss later.

@mikaylagawarecki mikaylagawarecki changed the title Migrate activation kernels to libtorch stable ABI [1/n] Migrate activation kernels to libtorch stable ABI Dec 29, 2025
@mikaylagawarecki mikaylagawarecki force-pushed the torch_stable_abi branch 3 times, most recently from f6f00f5 to f57fbaa Compare January 15, 2026 22:28

mergify bot commented Jan 15, 2026

Documentation preview: https://vllm--30908.org.readthedocs.build/en/30908/

@mikaylagawarecki mikaylagawarecki force-pushed the torch_stable_abi branch 2 times, most recently from 208d7e5 to 5c06dc5 Compare January 16, 2026 22:08

mergify bot commented Jan 16, 2026

Hi @mikaylagawarecki, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint



mergify bot commented Feb 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Cherry-picked from temp4 996b0bb with the following additional changes:

- Merged M6 Blackwell 256-bit vectorization (PackedTraits, ld256/st256,
  cc_major branching) with stable ABI APIs (DeviceGuard,
  get_current_cuda_stream, mutable_data_ptr/const_data_ptr,
  VLLM_STABLE_DISPATCH_FLOATING_TYPES)
- Added get_device_prop() cached device properties utility to
  csrc/stable/torch_utils.h for stable ABI compatible device queries
- Added explicit cuda_bf16.h/cuda_fp16.h includes for packed type
  intrinsics (hip equivalents on ROCm)
- Replaced c10::BFloat16/c10::Half with torch::headeronly:: equivalents
- ROCm fix: cuda_compat.h -> ../cuda_compat.h in stable/activation_kernels.cu
- ROCm fix: import _C_stable_libtorch in vllm/platforms/rocm.py

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES})
else()
target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES})
# Link against PyTorch's bundled libtorch_hip.so (for DeviceGuard registration)
@mikaylagawarecki (Contributor, Author) commented Feb 24, 2026:

On a ROCm machine, I found that I needed these changes to make sure vLLM was linking against these two .so files from torch (in particular the libamd.so HIP libraries packaged by torch); otherwise the stable DeviceGuard would not work correctly.

It seemed that two separate hipContexts were created: one by the raw HIP calls that vLLM made (which came from a HIP header from elsewhere), and one by the libtorch shims that called raw HIP APIs using the HIP headers packaged by torch.

Collaborator commented:

@mikaylagawarecki I'm not sure what's going on here. Do all extensions need this code? Or is vLLM doing something weird?

at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
@mikaylagawarecki (Contributor, Author) commented Feb 24, 2026:

Since it might be confusing why this is deleted: I kept the other declarations because the CPU torch_bindings.cpp includes ops.h. All the other ops are defined for CPU as well, but this op isn't defined for CPU, so its declaration is deleted.


mergify bot commented Feb 28, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 28, 2026
Comment on lines +988 to +991
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
# This ensures we only use C-shim APIs available in PyTorch 2.10+.
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)
Collaborator commented:

nit: comment that "_C_stable_libtorch is ABI compatible with PyTorch >= TORCH_TARGET_VERSION, which is currently set to 2.10" (explicitly state what we get).
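For reference, the hex constant packs the version into the high bytes of a 64-bit value; a quick decode (the bit layout below is inferred from the value in this diff, so treat it as an assumption rather than PyTorch's documented encoding):

```python
TORCH_TARGET_VERSION = 0x020A000000000000

# Major and minor appear to occupy the top two bytes of the 64-bit value:
# 0x02 -> major 2, 0x0A -> minor 10, i.e. PyTorch 2.10.
major = (TORCH_TARGET_VERSION >> 56) & 0xFF
minor = (TORCH_TARGET_VERSION >> 48) & 0xFF
print(f"{major}.{minor}")
```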

@@ -0,0 +1,25 @@
/*
* Stable ABI compatible dispatch utilities for vLLM.
Collaborator commented:

nit: Whenever you add "stable ABI" in a comment, you probably want to call it "libtorch stable ABI" for clarity.

STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
// Activation ops
// Activation function used in SwiGLU.
m.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
Collaborator commented:

Just to check... you are able to add tags to operators in the stable abi?

@mikaylagawarecki (Contributor, Author) replied:

Do you mean adding at::Tag to the registration? That is currently not ABI stable (but I don't believe at::Tag is used within the repo at the moment).

#include <mutex>
#include <string>
#include <vector>

Collaborator commented:

there's no namespacing at all, is that intentional?

Collaborator replied:

looks like vllm/csrc/ops.h has no namespacing in the first place lol


Labels

ci/build · cpu (Related to CPU backends) · documentation (Improvements or additions to documentation) · needs-rebase · nvidia · rocm (Related to AMD ROCm)
