
[6/n] Migrate activation kernels, gptq, gguf, non cutlass w8a8 to libtorch stable ABI #38757

Closed
mikaylagawarecki wants to merge 21 commits into vllm-project:main from mikaylagawarecki:new-stable-abi-phase6

Conversation

@mikaylagawarecki mikaylagawarecki commented Apr 1, 2026

Contributor

Stacked on #38671, only the top 11 commits are relevant. Commits to review https://github.com/vllm-project/vllm/pull/38757/changes/2c13410412de95b648e9cd8562431dbe9481f9ee..deea6618c38afb4735b442c61e2697c273654292

Note: some declarations are not deleted from csrc/ops.h despite being moved to csrc/libtorch_stable/ops.h. This is because the CPU build also uses these declarations. These are:

  • Activation kernels: silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, gelu_new, gelu_fast, gelu_quick
  • W8A8 INT8 quantization: static_scaled_int8_quant, dynamic_scaled_int8_quant

Purpose

#26946

Test Plan

pytest tests/kernels/core/test_activation.py
pytest tests/kernels/quantization/test_ggml.py
pytest tests/kernels/quantization/test_fp8_quant.py
pytest tests/kernels/quantization/test_int8_quant.py

Test Result

[Four screenshots of test results, captured 2026-04-01 and 2026-04-02]

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify Bot added ci/build nvidia rocm Related to AMD ROCm labels Apr 1, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 1, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request ports several CUDA kernels, including activation, AWQ, AllSpark, MLA, and Hadacore, to the stable ABI extension (_C_stable_libtorch) and enables this extension for ROCm (HIP). Key changes include the introduction of a device property cache using raw CUDA/HIP APIs, updating headers to use header-only Torch utilities, and refactoring source files to use torch::stable::Tensor. Review feedback identified a compilation error in torch_utils.h where std::once_flag cannot be used with std::deque::resize due to its non-copyable nature, and logic errors in hadacore_transform related to in-place processing and padding.

Comment on lines +15 to +39
#include <deque>
#include <mutex>
#include <string>
#include <vector>

// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
  STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)

// Device properties cache for stable ABI compatibility.
// Uses raw CUDA/HIP APIs instead of ATen functions.
// Using inline ensures a single instance across all translation units.
inline std::deque<std::once_flag> device_flags;
inline std::vector<cudaDeviceProp> device_properties;
inline std::once_flag vectors_init_flag;

inline void do_init_device_vectors() {
  int device_count;
  cudaError_t err = cudaGetDeviceCount(&device_count);
  if (err != cudaSuccess) {
    STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " +
                               std::string(cudaGetErrorString(err)));
  }
  device_flags.resize(device_count);
  device_properties.resize(device_count);

Contributor

high

The use of std::deque<std::once_flag> and calling resize() on it will cause a compilation error because std::once_flag is non-copyable and non-movable. std::deque::resize requires the element type to be MoveInsertable. A better approach is to use a std::unique_ptr<std::once_flag[]> to manage the flags dynamically after querying the device count. This requires including <memory>.

#include <memory>
#include <mutex>
#include <string>
#include <vector>

// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
  STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)

// Device properties cache for stable ABI compatibility.
// Uses raw CUDA/HIP APIs instead of ATen functions.
// Using inline ensures a single instance across all translation units.
inline std::unique_ptr<std::once_flag[]> device_flags;
inline std::vector<cudaDeviceProp> device_properties;
inline std::once_flag vectors_init_flag;

inline void do_init_device_vectors() {
  int device_count;
  cudaError_t err = cudaGetDeviceCount(&device_count);
  if (err != cudaSuccess) {
    STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " +
                               std::string(cudaGetErrorString(err)));
  }
  device_flags = std::make_unique<std::once_flag[]>(device_count);
  device_properties.resize(device_count);
}
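
For readers who want to see the suggested pattern in isolation, here is a minimal sketch of a per-device cache guarded by `std::unique_ptr<std::once_flag[]>` and `std::call_once`. It is not the code from this PR: `FakeDeviceProp` and `DevicePropCache` are hypothetical names, and the property query is faked so the sketch builds without CUDA or HIP. The two `static_assert`s only confirm that `std::once_flag` is neither copyable nor movable; they say nothing about whether a given standard library accepts `std::deque<std::once_flag>::resize`.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <type_traits>
#include <vector>

// std::once_flag can be neither copied nor moved.
static_assert(!std::is_copy_constructible_v<std::once_flag>);
static_assert(!std::is_move_constructible_v<std::once_flag>);

// Hypothetical stand-in for cudaDeviceProp so the sketch builds without CUDA.
struct FakeDeviceProp {
  int multiprocessor_count = 0;
};

// Hypothetical cache: one once_flag and one property slot per device.
class DevicePropCache {
 public:
  explicit DevicePropCache(std::size_t device_count)
      : flags_(std::make_unique<std::once_flag[]>(device_count)),
        props_(device_count) {}

  // Fills the slot for `device` on first use; later calls reuse it.
  const FakeDeviceProp& get(std::size_t device) {
    std::call_once(flags_[device], [&] {
      ++query_count_;  // stands in for a real device-property query
      props_[device].multiprocessor_count = 100 + static_cast<int>(device);
    });
    return props_[device];
  }

  // Number of times the (fake) query actually ran.
  int query_count() const { return query_count_.load(); }

 private:
  std::unique_ptr<std::once_flag[]> flags_;
  std::vector<FakeDeviceProp> props_;
  std::atomic<int> query_count_{0};
};
```

Because the flag array is allocated once at its final size, no element ever has to be copied or moved, which sidesteps the concern raised in the comment.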

Contributor Author

see my comment on [5/n] #38671 (comment)

Comment on lines 799 to +807
 hadacore::run_fht<SCALAR_TYPE>(x.data_ptr(), x.data_ptr(), x.numel(), had_size, stream);
 });

 if (numel % 256 != 0) {
-  out = out.narrow(0, 0, numel / had_size);
+  out = torch::stable::narrow(out, 0, 0, numel / had_size);
 }

 if (inplace && out.data_ptr() != x.data_ptr()) {
-  x.copy_(out.view(res_shape));
+  torch::stable::copy_(x, torch::stable::view(out, res_shape));

Contributor

high

There are critical logic errors in hadacore_transform regarding in-place operations and padding:

  1. Incorrect Output Pointer: The kernel run_fht is called with x.data_ptr() as the output pointer (line 799). If inplace is false, the input tensor x is modified, and the returned tensor out remains uninitialized.
  2. Incorrect Copy-back Logic: When inplace is true and padding occurs, x is reassigned to a new padded tensor. The result is written to this new tensor. However, line 807 copies from out (the original tensor) to x (the result tensor), which overwrites the computed result with the original input data. It should copy from x back to out.

To fix this, the kernel should write to out.data_ptr(). If inplace is true and padding occurred, the result should be copied from out (the padded result) back to the original input storage (after narrowing).
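
The flow the reviewer proposes can be sketched on the CPU as follows. This is only an illustration of "write to the output buffer, narrow, then copy back for in-place calls", not the hadacore kernel, and it does not settle whether the existing code has the bug described. `fwht_row`, `hadamard_transform`, and `kPadElems` (standing in for the 256 in the diff) are hypothetical names. The sketch assumes `had_size` is a power of two that divides both the element count and `kPadElems`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Unnormalized fast Walsh-Hadamard transform of one row of length n
// (n must be a power of two). Reads from `in`, writes to `out`.
void fwht_row(const float* in, float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) out[i] = in[i];
  for (std::size_t h = 1; h < n; h *= 2) {
    for (std::size_t i = 0; i < n; i += 2 * h) {
      for (std::size_t j = i; j < i + h; ++j) {
        const float a = out[j];
        const float b = out[j + h];
        out[j] = a + b;
        out[j + h] = a - b;
      }
    }
  }
}

// Hypothetical padding granularity, standing in for the 256 in the diff.
constexpr std::size_t kPadElems = 8;

// Transforms every row of length had_size and returns the result.
// If inplace is true, the result is also copied into x.
std::vector<float> hadamard_transform(std::vector<float>& x,
                                      std::size_t had_size, bool inplace) {
  const std::size_t numel = x.size();

  // Pad a copy of the input with zero rows up to a multiple of kPadElems.
  const std::size_t padded = (numel + kPadElems - 1) / kPadElems * kPadElems;
  std::vector<float> in(x);
  in.resize(padded, 0.0f);

  // The transform writes to the output buffer, not back into its input.
  std::vector<float> out(padded);
  for (std::size_t row = 0; row < padded / had_size; ++row) {
    fwht_row(in.data() + row * had_size, out.data() + row * had_size,
             had_size);
  }

  // Narrow the result back to the original number of elements.
  out.resize(numel);

  // For an in-place call, copy the result into the caller's storage.
  if (inplace) x = out;
  return out;
}
```

Unlike a real kernel, this sketch always copies its input, which keeps the in-place and out-of-place paths identical up to the final copy.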

Contributor Author

pre-existing?

@mikaylagawarecki mikaylagawarecki changed the title Migrate some shared CUDA/RoCM kernels to libtorch stable ABI [6/n] Migrate some shared CUDA/RoCM kernels to libtorch stable ABI Apr 1, 2026
@mikaylagawarecki mikaylagawarecki force-pushed the new-stable-abi-phase6 branch 2 times, most recently from 969bfb0 to 60e21ce Compare April 1, 2026 22:28
@mergify

mergify Bot commented Apr 2, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 2, 2026
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pure move, no code changes. Preparatory step for stable ABI migration.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Restructure the stable ABI extension build so it compiles on both CUDA
and HIP:
- Widen outer guard to include HIP
- Move CUDA-only sources (CUTLASS, FP4, AWQ, permute_cols) into
  a CUDA-conditional block
- Gate USE_CUDA / CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL to CUDA;
  define USE_ROCM for HIP
- Link PyTorch's bundled libamdhip64.so on ROCm to avoid a dual HIP
  runtime (from 985769a)
- Enable _C_stable_libtorch in setup.py for HIP builds

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Move 9 basic activation ops (silu_and_mul, mul_and_silu, gelu_and_mul,
gelu_tanh_and_mul, fatrelu_and_mul, swigluoai_and_mul, gelu_new,
gelu_fast, gelu_quick) from the _C extension to _C_stable_libtorch.

Convert ATen types/APIs to stable ABI equivalents:
- torch::Tensor -> torch::stable::Tensor
- ATen device guard/stream -> stable accelerator APIs
- VLLM_DISPATCH_FLOATING_TYPES -> VLLM_STABLE_DISPATCH_FLOATING_TYPES
- data_ptr -> mutable_data_ptr

Quantized activation ops (silu_and_mul_quant,
persistent_masked_m_silu_mul_quant) remain in _C.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
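
As a reminder of what one of the migrated ops computes, here is a CPU reference sketch of `silu_and_mul`, assuming the usual definition: the input's last dimension has size 2*d, the first half is passed through SiLU, and the result is multiplied elementwise by the second half. It is not the CUDA kernel from this commit. `silu_and_mul_ref` and its flat `std::vector` layout are hypothetical choices made for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU activation: x * sigmoid(x).
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// input holds num_tokens rows of 2*d values (row-major); the result holds
// num_tokens rows of d values, with
//   out[t][i] = silu(input[t][i]) * input[t][d + i].
std::vector<float> silu_and_mul_ref(const std::vector<float>& input,
                                    std::size_t num_tokens, std::size_t d) {
  std::vector<float> out(num_tokens * d);
  for (std::size_t t = 0; t < num_tokens; ++t) {
    for (std::size_t i = 0; i < d; ++i) {
      const float gate = input[t * 2 * d + i];
      const float up = input[t * 2 * d + d + i];
      out[t * d + i] = silu(gate) * up;
    }
  }
  return out;
}
```

A reference like this can be handy when checking that a kernel still produces the same values after its tensor types change.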
@@ -1,17 +1,20 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../../cuda_compat.h"

Contributor Author

changes to use relative paths are due to hipification errors I ran into after moving files into the libtorch_stable subdirectory

torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale, // various shapes
std::optional<std::tuple<int64_t, int64_t>>

Contributor Author

The schema change here is because StableIValue conversions don't support tuples. From the user's perspective in Python there is no change. We also assert the size below, on line 219.
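
A minimal sketch of the kind of size check mentioned above, assuming the group shape now arrives as an optional integer list (schema `int[]?`) and is expected to hold exactly two values. `parse_group_shape` is a hypothetical helper, not a function from this PR, and it throws a standard exception where the real code would presumably use a Torch check macro.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical helper: converts an optional integer list into the pair the
// kernel logic expects, rejecting lists whose length is not exactly 2.
std::optional<std::pair<int64_t, int64_t>> parse_group_shape(
    const std::optional<std::vector<int64_t>>& group_shape) {
  if (!group_shape.has_value()) return std::nullopt;
  if (group_shape->size() != 2) {
    throw std::invalid_argument("group_shape must have exactly 2 elements");
  }
  return std::make_pair((*group_shape)[0], (*group_shape)[1]);
}
```

Validating the length at the boundary keeps the rest of the code working with a fixed-size pair even though the schema is looser.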

Comment thread csrc/ops.h
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,

Contributor Author

note: some declarations are not removed, as the CPU build is using the declarations from ops.h too (tests will fail if we remove them)

Contributor

Can you track these in the PR description?

Contributor Author

done

@mikaylagawarecki mikaylagawarecki force-pushed the new-stable-abi-phase6 branch 2 times, most recently from 675961a to e8f3b67 Compare April 2, 2026 18:36
Migrate static_scaled_fp8_quant, dynamic_scaled_fp8_quant, and
dynamic_per_token_scaled_fp8_quant from _C to _C_stable_libtorch.

Shared headers (common.cuh, utils.cuh) updated to work with both
targets: utils.cuh uses torch::headeronly types; common.cuh uses

Schema changed from (int,int)? to int[]? for group_shape to work
with TORCH_BOX (std::tuple is not trivially copyable).

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Migrate gptq_gemm and gptq_shuffle from _C to _C_stable_libtorch.

Key conversions: torch::Tensor -> torch::stable::Tensor, cuBLAS handle
via get_current_cuda_blas_handle(), device().is_meta() check via
DeviceType::Meta comparison, tensor creation via new_zeros/empty.

Also removes #ifndef USE_ROCM guard on get_current_cuda_blas_handle()
in torch_utils.h — hipify handles cublas_v2.h -> hipblas/hipblas.h
and cublasHandle_t -> hipblasHandle_t automatically.

Sub-headers (compat.cuh, matrix_view.cuh, qdq_*.cuh) are pure CUDA
device code and need no changes.

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Co-authored-by: Claude

Comment thread csrc/cuda_vec_utils.cuh

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

Contributor

why do we need these?

@mikaylagawarecki mikaylagawarecki Apr 2, 2026

Contributor Author

Concrete error without these is

FAILED: [code=1] CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o 
ccache /opt/rocm/lib/llvm/bin/clang++  -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_C_stable_libtorch -DTORCH_TARGET_VERSION=0x020A000000000000ULL -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_PROF_API=1 -DUSE_ROCM -DUSE_RPC -DUSE_TENSORPIPE -D_C_stable_libtorch_EXPORTS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1 -I/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc -isystem /home/mg1998/.conda/envs/pytorch/include/python3.12 -isystem /home/mg1998/.conda/envs/pytorch/lib/python3.12/site-packages/torch/include -isystem /home/mg1998/.conda/envs/pytorch/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/fbcode/platform010/lib/rocm-6.4.2/include/hiprand -isystem /usr/local/fbcode/platform010/lib/rocm-6.4.2/include/rocrand -Wno-unused-result -O2 -g -DNDEBUG -std=gnu++17 --offload-arch=gfx942 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -DUSE_ROCM -DENABLE_FP8 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -Werror=unused-variable -fno-gpu-rdc -DTORCH_HIP_VERSION=701 -Wno-shift-count-negative -Wno-shift-count-overflow -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -DHIP_ENABLE_WARP_SYNC_BUILTINS -DHIPBLASLT_OUTER_VEC -DUSE_ROCM_CK_GEMM -MD -MT CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o -MF CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o.d -o CMakeFiles/_C_stable_libtorch.dir/csrc/libtorch_stable/activation_kernels.hip.o -x hip -c /data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/activation_kernels.hip
In file included from /data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/activation_kernels.hip:8:
/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/../hip_vec_utils.cuh:78:28: error: unknown type name '__hip_bfloat162'; did you mean 'hip_bfloat16'?
   78 | struct PackedTypeConverter<__hip_bfloat162> {
      |                            ^~~~~~~~~~~~~~~
      |                            hip_bfloat16
/opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
   57 | struct hip_bfloat16
      |        ^
In file included from /data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/activation_kernels.hip:8:
/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312/csrc/libtorch_stable/../hip_vec_utils.cuh:79:16: error: unknown type name '__hip_bfloat16'; did you mean 'hip_bfloat16'?
   79 |   using Type = __hip_bfloat16;
      |                ^~~~~~~~~~~~~~
      |                hip_bfloat16
/opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h:57:8: note: 'hip_bfloat16' declared here
   57 | struct hip_bfloat16
      |        ^

Ah, concretely it seems like there is a bug here https://github.com/pytorch/pytorch/blob/main/torch/headeronly/util/BFloat16.h#L15-L17: the #include <cuda_bf16.h> gets hipified but the !defined(USE_ROCM) doesn't

We defined USE_ROCM for _C_stable_libtorch to expose some of the shims that are gated :/

Comment thread CMakeLists.txt
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cutlass_extensions/common.cpp"

Contributor

how come this common.cpp needs to move to if CUDA?

@mikaylagawarecki mikaylagawarecki Apr 2, 2026

Contributor Author

edit: hmm wait just to confirm you were referring to csrc/cuda_utils_kernels.cu not common.cpp (which is correctly CUDA-only) right?

technically csrc/cuda_utils_kernels.cu should be shared cuda/rocm, but there's some issues with it being in sources for both extensions when building on rocm, so I want to punt that problem which will be solved when we fully migrate it out of _C

The error looks like

CMake Error:
  Running

   '/home/mg1998/.conda/envs/pytorch/bin/ninja' '-C' '/data/users/mg1998/vllm/build/temp.linux-x86_64-cpython-312' '-t' 'recompact'

  failed with:

   ninja: error: build.ninja:713: multiple rules generate csrc/hip_utils_kernels.hip

Comment thread setup.py
# also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is
# fixed
-if _is_cuda():
+if _is_cuda() or _is_hip():

Contributor

exciting

Comment thread CMakeLists.txt
# If PyTorch doesn't bundle libamdhip64 (built from source against system
# ROCm), there is only one copy in the process and no action is needed —
# the HIP compiler already links the system libamdhip64 automatically.
if(VLLM_GPU_LANG STREQUAL "HIP")

Contributor

to improve my understanding, this code basically specifically picks out the amdhip64 that pytorch bundles in order to have deterministic correct results and not get corrupted?

Contributor Author

yea concretely there seem to be two cases

  1. pip installed pytorch on rocm (local) -- torch bundles libamdhip64, if we have raw hipFoo calls in vllm they will use the system rocm, but calls into hipFoo from libtorch itself will use torch's bundled libamdhip64 and there will be two device contexts -- we get GPU core dumps
  2. rocm pytorch docker image (e.g. vllm CI) -- torch does not bundle libamdhip64, we are good.
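
One way to check which of the two cases applies to a running process is to count the loaded shared objects whose path contains `libamdhip64`. The sketch below does this with `dl_iterate_phdr`, so it is Linux-only. `count_loaded_objects` and `MatchState` are hypothetical names and are not part of vLLM or PyTorch. A count above one would suggest two HIP runtimes are loaded, though the same file reached through two different paths could also be counted twice.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <link.h>  // dl_iterate_phdr (Linux, glibc or musl)

// State passed through dl_iterate_phdr's void* argument.
struct MatchState {
  const char* needle;
  int count;
};

// Counts loaded shared objects whose path contains name_fragment.
int count_loaded_objects(const char* name_fragment) {
  MatchState state{name_fragment, 0};
  dl_iterate_phdr(
      [](dl_phdr_info* info, std::size_t /*size*/, void* data) -> int {
        auto* s = static_cast<MatchState*>(data);
        if (info->dlpi_name != nullptr &&
            std::strstr(info->dlpi_name, s->needle) != nullptr) {
          ++s->count;
        }
        return 0;  // returning 0 keeps the iteration going
      },
      &state);
  return state.count;
}
```

Calling `count_loaded_objects("libamdhip64")` after both extensions are imported would be the intended use.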

@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review April 3, 2026 15:09
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Apr 3, 2026
@mergify

mergify Bot commented Apr 8, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify

mergify Bot commented May 18, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mikaylagawarecki.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 18, 2026
@janeyx99

Contributor

This PR will be landed in #42663 and can be closed

@Harry-Chen

Member

Superseded by newer PRs.

@Harry-Chen Harry-Chen closed this May 20, 2026
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 20, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 20, 2026