Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ set(VLLM_EXT_SRC
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
Expand Down Expand Up @@ -960,6 +959,42 @@ define_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)


if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
  #
  # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
  #
  set(VLLM_STABLE_EXT_SRC
    "csrc/stable/activation_kernels.cu"
    "csrc/stable/torch_bindings.cpp")

  if(VLLM_GPU_LANG STREQUAL "CUDA")
    set_gencode_flags_for_srcs(
      SRCS "${VLLM_STABLE_EXT_SRC}"
      CUDA_ARCHS "${CUDA_ARCHS}")
  endif()

  message(STATUS "Enabling C_stable extension.")
  define_extension_target(
    _C_stable_libtorch
    DESTINATION vllm
    LANGUAGE ${VLLM_GPU_LANG}
    SOURCES ${VLLM_STABLE_EXT_SRC}
    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
    ARCHITECTURES ${VLLM_GPU_ARCHES}
    USE_SABI 3
    WITH_SOABI)

  # _C_stable_libtorch is ABI-compatible with PyTorch >= TORCH_TARGET_VERSION,
  # which is currently set to 2.10 (0x020A... = major 2, minor 0x0A). Defining
  # it ensures we only use C-shim APIs available in PyTorch 2.10+.
  target_compile_definitions(_C_stable_libtorch PRIVATE
    TORCH_TARGET_VERSION=0x020A000000000000ULL)

  # Needed to use cuda APIs from C-shim
  target_compile_definitions(_C_stable_libtorch PRIVATE
    USE_CUDA)
endif()

#
# _moe_C extension
#
Expand Down
20 changes: 19 additions & 1 deletion cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,25 @@ function (define_extension_target MOD_NAME)
if (ARG_LANGUAGE STREQUAL "CUDA")
target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES})
else()
target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES})
# Link against PyTorch's bundled libtorch_hip.so (for DeviceGuard registration)
Copy link
Copy Markdown
Contributor Author

@mikaylagawarecki mikaylagawarecki Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a RoCM machine, I found that I needed these changes to make sure that vllm was linking to these two .so from torch (in particular the libamd.so hip headers packaged by torch) otherwise the stable DeviceGuard would not work correctly

It seemed that 2 separate hipContexts were created by the raw hip calls that vllm did (which were from a hip header from elsewhere) and the libtorch shims that called raw hip APIs, which used the hip headers packaged by torch

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mikaylagawarecki I'm not sure what's going on here. Do all extensions need this code? Or is vLLM doing something weird?

# and libamdhip64.so (to share a single HIP runtime with PyTorch).
find_library(TORCH_HIP_LIBRARY torch_hip PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH)
find_library(TORCH_AMDHIP64_LIBRARY amdhip64 PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH)

set(_hip_libs)
if (TORCH_HIP_LIBRARY)
list(APPEND _hip_libs ${TORCH_HIP_LIBRARY})
endif()
if (TORCH_AMDHIP64_LIBRARY)
list(APPEND _hip_libs ${TORCH_AMDHIP64_LIBRARY})
# Ensure PyTorch's bundled libamdhip64.so is found at runtime, not system ROCm's.
set(_torch_lib_dir "${TORCH_INSTALL_PREFIX}/lib")
set_target_properties(${MOD_NAME} PROPERTIES
BUILD_RPATH "${_torch_lib_dir}"
INSTALL_RPATH "${_torch_lib_dir}")
endif()

target_link_libraries(${MOD_NAME} PRIVATE torch ${_hip_libs} ${TORCH_LIBRARIES} ${ARG_LIBRARIES})
endif()

install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME})
Expand Down
2 changes: 0 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ void persistent_masked_m_silu_mul_quant(
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
Copy link
Copy Markdown
Contributor Author

@mikaylagawarecki mikaylagawarecki Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it might be confusing why this is deleted -- I kept the other declarations because the cpu torch_bindings.cpp includes ops.h. All the other ops are also defined for CPU as well, but this op isn't defined for CPU so its declaration is deleted


void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
Expand Down
300 changes: 162 additions & 138 deletions csrc/activation_kernels.cu → csrc/stable/activation_kernels.cu

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions csrc/stable/dispatch_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Stable ABI compatible dispatch utilities for vLLM.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Whenever you add "stable abi" in a comment you probably want to call it "libtorch stable abi" to clarify

* Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros
* instead of the ATen (AT_*) macros.
*
* These macros use:
* - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH
* - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE
* - torch::headeronly::ScalarType instead of at::ScalarType
*
* Add more macros here as needed when migrating additional kernels.
*/
#pragma once

#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/ScalarType.h>

// Dispatch cases for the three floating-point dtypes vLLM kernels support:
// fp32, fp16, and bf16. Built on the header-only THO_DISPATCH_CASE macro so
// no ATen symbols are required.
#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)

// Header-only replacement for VLLM_DISPATCH_FLOATING_TYPES: switches on TYPE
// (a torch::headeronly::ScalarType) and instantiates the body for
// Float/Half/BFloat16; NAME is used in the error message for unhandled types.
#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
20 changes: 20 additions & 0 deletions csrc/stable/ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

// Declarations for activation kernels implemented in
// csrc/stable/activation_kernels.cu using only the libtorch stable ABI
// (torch::stable::Tensor instead of at::Tensor). All ops write into a
// caller-provided `out` tensor.

// Gated activation functions (input: [..., 2*d] -> output: [..., d])
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_tanh_and_mul(torch::stable::Tensor& out,
torch::stable::Tensor& input);
// FATReLU gate: values below `threshold` are zeroed before the multiply.
void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
double threshold);
void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
double alpha, double limit);

// Element-wise activation functions (input: [..., d] -> output: [..., d])
void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input);
54 changes: 54 additions & 0 deletions csrc/stable/torch_bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "ops.h"
#include "core/registration.h"

#include <torch/csrc/stable/library.h>

// Register ops using STABLE_TORCH_LIBRARY for stable ABI compatibility.
// Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code.
// Schema definitions for the activation ops, registered under namespace "_C"
// so they remain reachable as torch.ops._C.<op_name> after moving out of the
// legacy _C extension. `Tensor!` marks the mutated output argument.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
  // Activation ops
  // Activation function used in SwiGLU.
  m.def("silu_and_mul(Tensor! result, Tensor input) -> ()");

  m.def("mul_and_silu(Tensor! out, Tensor input) -> ()");

  // Activation function used in GeGLU with `none` approximation.
  m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");

  // Activation function used in GeGLU with `tanh` approximation.
  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");

  // FATReLU implementation.
  m.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");

  m.def(
      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
      "limit=7.0) -> ()");

  // GELU implementation used in GPT-2.
  m.def("gelu_new(Tensor! out, Tensor input) -> ()");

  // Approximate GELU implementation.
  m.def("gelu_fast(Tensor! out, Tensor input) -> ()");

  // Quick GELU implementation.
  m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
}

// CUDA implementations for the ops declared in the "_C" fragment above.
// TORCH_BOX wraps each unboxed C++ function into the boxed calling convention
// required by the stable-ABI registration path.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
// Gated activations
m.impl("silu_and_mul", TORCH_BOX(&silu_and_mul));
m.impl("mul_and_silu", TORCH_BOX(&mul_and_silu));
m.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul));
m.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul));
m.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul));
m.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul));

// Element-wise activations
m.impl("gelu_new", TORCH_BOX(&gelu_new));
m.impl("gelu_fast", TORCH_BOX(&gelu_fast));
m.impl("gelu_quick", TORCH_BOX(&gelu_quick));
}

REGISTER_EXTENSION(_C_stable_libtorch)
60 changes: 60 additions & 0 deletions csrc/stable/torch_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/Exception.h>

#include <cuda_runtime.h>
#include <deque>
#include <mutex>
#include <string>
#include <vector>

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no namespacing at all, is that intentional?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like vllm/csrc
/ops.h has no namespacing in the first place lol

// Device properties cache for stable ABI compatibility.
// Uses raw CUDA/HIP APIs instead of ATen functions.
// Thread-safe: each device's properties are queried exactly once.
// NOTE: a deque is used for the flags because std::once_flag is neither
// copyable nor movable, so it cannot live in a std::vector that may
// relocate elements; deque growth leaves existing elements in place.
inline std::deque<std::once_flag> device_prop_flags;
inline std::vector<cudaDeviceProp> device_prop_cache;
inline std::once_flag device_prop_vectors_init_flag;

inline void init_device_prop_vectors() {
int device_count;
cudaError_t err = cudaGetDeviceCount(&device_count);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " +
std::string(cudaGetErrorString(err)));
}
device_prop_flags.resize(device_count);
device_prop_cache.resize(device_count);
}

inline void init_device_prop(int device_index) {
cudaDeviceProp prop{};
cudaError_t err = cudaGetDeviceProperties(&prop, device_index);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " +
std::string(cudaGetErrorString(err)));
}
device_prop_cache[device_index] = prop;
}

// Returns a pointer to the cached cudaDeviceProp for the CURRENT device
// (as reported by cudaGetDevice). Lazily initializes the cache structures
// on first use, then the individual device entry on its first use; both
// steps are serialized with std::call_once so concurrent callers are safe.
// The returned pointer stays valid for the process lifetime because the
// cache vector is sized once and never resized afterwards.
inline cudaDeviceProp* get_device_prop() {
std::call_once(device_prop_vectors_init_flag, init_device_prop_vectors);
int device_index;
cudaError_t err = cudaGetDevice(&device_index);
if (err != cudaSuccess) {
STD_TORCH_CHECK(
false, "cudaGetDevice failed: " + std::string(cudaGetErrorString(err)));
}
std::call_once(device_prop_flags[device_index], init_device_prop,
device_index);
return &device_prop_cache[device_index];
}

// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
// Goes through the AOTI C-shim (aoti_torch_get_current_cuda_stream) rather
// than at::cuda::getCurrentCUDAStream, which is not part of the stable ABI.
// NOTE(review): default device_index = -1 presumably means "current device"
// in the shim — confirm against the shim's documentation.
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
return reinterpret_cast<cudaStream_t>(stream_ptr);
}
37 changes: 0 additions & 37 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif

// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
Expand All @@ -116,39 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
#endif

ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

// FATReLU implementation.
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);

ops.def(
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
"limit=7.0) "
"-> ()");
ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);

// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);

// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ def extract_precompiled_and_patch_package(
with zipfile.ZipFile(wheel_path) as wheel:
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_C_stable_libtorch.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/_flashmla_extension_C.abi3.so",
Expand Down Expand Up @@ -989,6 +990,8 @@ def _read_requirements(filename: str) -> list[str]:

if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch"))

package_data = {
"vllm": [
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# import custom ops, trigger op registration
import vllm._C # noqa
import vllm._C_stable_libtorch # noqa
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
Expand Down
5 changes: 5 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)

try:
import vllm._C_stable_libtorch # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C_stable_libtorch with %r", e)

# import custom ops, trigger op registration
try:
import vllm._rocm_C # noqa: F401
Expand Down