diff --git a/CMakeLists.txt b/CMakeLists.txt index a6f7f69468d1..645e07504b2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -288,7 +288,6 @@ set(VLLM_EXT_SRC "csrc/attention/merge_attn_states.cu" "csrc/attention/vertical_slash_index.cu" "csrc/pos_encoding_kernels.cu" - "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" @@ -960,6 +959,42 @@ define_extension_target( # Setting this variable sidesteps the issue by calling the driver directly. target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + +if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") + # + # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) + # + set(VLLM_STABLE_EXT_SRC + "csrc/stable/activation_kernels.cu" + "csrc/stable/torch_bindings.cpp") + + if(VLLM_GPU_LANG STREQUAL "CUDA") + set_gencode_flags_for_srcs( + SRCS "${VLLM_STABLE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + endif() + + message(STATUS "Enabling C_stable extension.") + define_extension_target( + _C_stable_libtorch + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_STABLE_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + + # Set TORCH_TARGET_VERSION for stable ABI compatibility. + # This ensures we only use C-shim APIs available in PyTorch 2.10+. 
+ target_compile_definitions(_C_stable_libtorch PRIVATE + TORCH_TARGET_VERSION=0x020A000000000000ULL) + + # Needed to use cuda APIs from C-shim + target_compile_definitions(_C_stable_libtorch PRIVATE + USE_CUDA) +endif() + # # _moe_C extension # diff --git a/cmake/utils.cmake b/cmake/utils.cmake index bdb2ba74d944..448d7ae3fd8e 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -541,7 +541,25 @@ function (define_extension_target MOD_NAME) if (ARG_LANGUAGE STREQUAL "CUDA") target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES}) else() - target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES}) + # Link against PyTorch's bundled libtorch_hip.so (for DeviceGuard registration) + # and libamdhip64.so (to share a single HIP runtime with PyTorch). + find_library(TORCH_HIP_LIBRARY torch_hip PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) + find_library(TORCH_AMDHIP64_LIBRARY amdhip64 PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) + + set(_hip_libs) + if (TORCH_HIP_LIBRARY) + list(APPEND _hip_libs ${TORCH_HIP_LIBRARY}) + endif() + if (TORCH_AMDHIP64_LIBRARY) + list(APPEND _hip_libs ${TORCH_AMDHIP64_LIBRARY}) + # Ensure PyTorch's bundled libamdhip64.so is found at runtime, not system ROCm's. 
+ set(_torch_lib_dir "${TORCH_INSTALL_PREFIX}/lib") + set_target_properties(${MOD_NAME} PROPERTIES + BUILD_RPATH "${_torch_lib_dir}" + INSTALL_RPATH "${_torch_lib_dir}") + endif() + + target_link_libraries(${MOD_NAME} PRIVATE torch ${_hip_libs} ${TORCH_LIBRARIES} ${ARG_LIBRARIES}) endif() install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME}) diff --git a/csrc/ops.h b/csrc/ops.h index 5e2b475fa8c1..477019caa68f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -165,8 +165,6 @@ void persistent_masked_m_silu_mul_quant( at::Tensor& y_s, // (E, T, H//group_size) [OUT] bool use_ue8m0); -void mul_and_silu(torch::Tensor& out, torch::Tensor& input); - void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/activation_kernels.cu b/csrc/stable/activation_kernels.cu similarity index 77% rename from csrc/activation_kernels.cu rename to csrc/stable/activation_kernels.cu index 99fa42f75e99..54dabd112811 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/stable/activation_kernels.cu @@ -1,11 +1,24 @@ -#include -#include -#include +#include +#include +#include +#include +#include +#include #include +#include -#include "cuda_compat.h" +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#include "../cuda_compat.h" #include "dispatch_utils.h" +#include "torch_utils.h" namespace vllm { @@ -69,12 +82,12 @@ template struct PackedTraits; template <> -struct PackedTraits { +struct PackedTraits { using packed_t = __nv_bfloat162; }; template <> -struct PackedTraits { +struct PackedTraits { using packed_t = __half2; }; @@ -269,65 +282,69 @@ packed_gelu_tanh_kernel(const packed_t& val) { } // namespace vllm -// Launch activation and gating kernel. +// Launch activation and gating kernel using stable APIs. // Use ACT_FIRST (bool) indicating whether to apply the activation function // first. 
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = get_device_prop()->major; \ + int support_vec = (cc_major >= 
10 && num_tokens > 128) ? 32 : 16; \ + int vec_size = support_vec / input.element_size(); \ + const bool use_vec = (d % vec_size == 0); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (cc_major >= 10 && num_tokens > 128) { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTraits::packed_t, \ + KERNEL, \ + PACKED_KERNEL::packed_t>, \ + ACT_FIRST, true, true> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } else { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTraits::packed_t, \ + KERNEL, \ + PACKED_KERNEL::packed_t>, \ + ACT_FIRST, true, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTraits::packed_t, \ + KERNEL, \ + PACKED_KERNEL::packed_t>, \ + ACT_FIRST, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ } -void silu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void silu_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel, true); } -void mul_and_silu(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void mul_and_silu(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { // The difference between mul_and_silu and silu_and_mul is that mul_and_silu // applies the silu to the latter half of the input. 
@@ -335,15 +352,15 @@ void mul_and_silu(torch::Tensor& out, // [..., d] false); } -void gelu_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel, true); } -void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true); @@ -498,16 +515,17 @@ __global__ void swigluoai_and_mul_kernel( return; \ } \ dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int cc_major = get_device_prop()->major; \ int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ - int vec_size = support_vec / at::elementSize(dtype); \ + int vec_size = support_vec / input.element_size(); \ const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ if (cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::PackedTraits::packed_t, \ @@ -515,11 +533,11 @@ __global__ void swigluoai_and_mul_kernel( PACKED_KERNEL< \ typename vllm::PackedTraits::packed_t>, \ true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d, \ - PARAM); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ }); \ } else { \ - 
VLLM_DISPATCH_FLOATING_TYPES( \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::PackedTraits::packed_t, \ @@ -527,48 +545,58 @@ __global__ void swigluoai_and_mul_kernel( PACKED_KERNEL< \ typename vllm::PackedTraits::packed_t>, \ true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d, \ - PARAM); \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ }); \ } \ } else { \ dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ - vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - false><<>>( \ - out.data_ptr(), input.data_ptr(), d, PARAM); \ - }); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + dtype, "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param< \ + scalar_t, typename vllm::PackedTraits::packed_t, \ + KERNEL, \ + PACKED_KERNEL::packed_t>, \ + false><<>>( \ + out.mutable_data_ptr(), \ + input.const_data_ptr(), d, PARAM); \ + }); \ } -#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ - vllm::swigluoai_and_mul_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d, ALPHA, \ - LIMIT); \ +// Launch the swigluoai gated activation kernel. Like the other launchers, +// return early for zero tokens: a zero-sized grid is an invalid launch. +#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const 
cudaStream_t stream = get_current_cuda_stream(); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \ + vllm::swigluoai_and_mul_kernel> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d, \ + ALPHA, LIMIT); \ }); -void fatrelu_and_mul(torch::Tensor& out, // [..., d], - torch::Tensor& input, // [..., 2 * d] +void fatrelu_and_mul(torch::stable::Tensor& out, // [..., d], + torch::stable::Tensor& input, // [..., 2 * d] double threshold) { LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM( vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold); } -void swigluoai_and_mul(torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., 2 * d] + +void swigluoai_and_mul(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input, // [..., 2 * d] double alpha, double limit) { LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit); } + namespace vllm { // Element-wise activation kernel template. @@ -619,43 +642,44 @@ __global__ void activation_kernel( } // namespace vllm -// Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 
32 : 16; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, true, true> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, true, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ - vllm::activation_kernel, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); \ +// Launch element-wise activation kernel using stable APIs. +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = get_device_prop()->major; \ + int support_vec = (cc_major >= 10 && num_tokens > 128) ? 
32 : 16; \ + int vec_size = support_vec / input.element_size(); \ + const bool use_vec = (d % vec_size == 0); \ + const torch::stable::accelerator::DeviceGuard device_guard( \ + input.get_device_index()); \ + const cudaStream_t stream = get_current_cuda_stream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (cc_major >= 10 && num_tokens > 128) { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, true> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } else { \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, true, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_STABLE_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ + vllm::activation_kernel, false> \ + <<>>(out.mutable_data_ptr(), \ + input.const_data_ptr(), d); \ + }); \ } namespace vllm { @@ -683,20 +707,20 @@ __device__ __forceinline__ T gelu_quick_kernel(const T& x) { } // namespace vllm -void gelu_new(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } -void gelu_quick(torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_quick(torch::stable::Tensor& out, // [..., d] + torch::stable::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); } diff --git a/csrc/stable/dispatch_utils.h b/csrc/stable/dispatch_utils.h new file mode 100644 index 000000000000..b7eacca04343 --- /dev/null +++ 
b/csrc/stable/dispatch_utils.h @@ -0,0 +1,25 @@ +/* + * Stable ABI compatible dispatch utilities for vLLM. + * Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros + * instead of the ATen (AT_*) macros. + * + * These macros use: + * - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH + * - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE + * - torch::headeronly::ScalarType instead of at::ScalarType + * + * Add more macros here as needed when migrating additional kernels. + */ +#pragma once + +#include +#include + +#define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) diff --git a/csrc/stable/ops.h b/csrc/stable/ops.h new file mode 100644 index 000000000000..a26b6f454c29 --- /dev/null +++ b/csrc/stable/ops.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +// Gated activation functions (input: [..., 2*d] -> output: [..., d]) +void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); +void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_tanh_and_mul(torch::stable::Tensor& out, + torch::stable::Tensor& input); +void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input, + double threshold); +void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input, + double alpha, double limit); + +// Element-wise activation functions (input: [..., d] -> output: [..., d]) +void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input); +void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& 
input); +void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input); diff --git a/csrc/stable/torch_bindings.cpp b/csrc/stable/torch_bindings.cpp new file mode 100644 index 000000000000..5e18a1a76372 --- /dev/null +++ b/csrc/stable/torch_bindings.cpp @@ -0,0 +1,54 @@ +#include "ops.h" +#include "core/registration.h" + +#include + +// Register ops using STABLE_TORCH_LIBRARY for stable ABI compatibility. +// Note: We register under namespace "_C" so ops are accessible as +// torch.ops._C. for compatibility with existing code. +STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) { + // Activation ops + // Activation function used in SwiGLU. + m.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); + + m.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); + + // Activation function used in GeGLU with `none` approximation. + m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + + // Activation function used in GeGLU with `tanh` approximation. + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + + // FATReLU implementation. + m.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); + + m.def( + "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " + "limit=7.0) -> ()"); + + // GELU implementation used in GPT-2. + m.def("gelu_new(Tensor! out, Tensor input) -> ()"); + + // Approximate GELU implementation. + m.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + + // Quick GELU implementation. + m.def("gelu_quick(Tensor! 
out, Tensor input) -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + // Gated activations + m.impl("silu_and_mul", TORCH_BOX(&silu_and_mul)); + m.impl("mul_and_silu", TORCH_BOX(&mul_and_silu)); + m.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul)); + m.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul)); + m.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul)); + m.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul)); + + // Element-wise activations + m.impl("gelu_new", TORCH_BOX(&gelu_new)); + m.impl("gelu_fast", TORCH_BOX(&gelu_fast)); + m.impl("gelu_quick", TORCH_BOX(&gelu_quick)); +} + +REGISTER_EXTENSION(_C_stable_libtorch) diff --git a/csrc/stable/torch_utils.h b/csrc/stable/torch_utils.h new file mode 100644 index 000000000000..a15085c8f3da --- /dev/null +++ b/csrc/stable/torch_utils.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +// Device properties cache for stable ABI compatibility. +// Uses raw CUDA/HIP APIs instead of ATen functions. +// Thread-safe: each device's properties are queried exactly once. 
+inline std::deque device_prop_flags; +inline std::vector device_prop_cache; +inline std::once_flag device_prop_vectors_init_flag; + +inline void init_device_prop_vectors() { + int device_count; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " + + std::string(cudaGetErrorString(err))); + } + device_prop_flags.resize(device_count); + device_prop_cache.resize(device_count); +} + +inline void init_device_prop(int device_index) { + cudaDeviceProp prop{}; + cudaError_t err = cudaGetDeviceProperties(&prop, device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_prop_cache[device_index] = prop; +} + +inline cudaDeviceProp* get_device_prop() { + std::call_once(device_prop_vectors_init_flag, init_device_prop_vectors); + int device_index; + cudaError_t err = cudaGetDevice(&device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK( + false, "cudaGetDevice failed: " + std::string(cudaGetErrorString(err))); + } + std::call_once(device_prop_flags[device_index], init_device_prop, + device_index); + return &device_prop_cache[device_index]; +} + +// Utility to get the current CUDA stream for a given device using stable APIs. +// Returns a cudaStream_t for use in kernel launches. +inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); + return reinterpret_cast(stream_ptr); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c16b9c223f62..591e499e5ab5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -101,10 +101,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif // Activation ops - // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! 
result, Tensor input) -> ()"); - ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - ops.def( "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); @@ -116,39 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); #endif - ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); - ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); - - // Activation function used in GeGLU with `none` approximation. - ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); - - // Activation function used in GeGLU with `tanh` approximation. - ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - - // FATReLU implementation. - ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); - ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); - - ops.def( - "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " - "limit=7.0) " - "-> ()"); - ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul); - - // GELU implementation used in GPT-2. - ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_new", torch::kCUDA, &gelu_new); - - // Approximate GELU implementation. - ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); - - // Quick GELU implementation. - ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); - ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. 
ops.def( diff --git a/setup.py b/setup.py index 8dea355da7c8..c3377cd1a8d9 100644 --- a/setup.py +++ b/setup.py @@ -674,6 +674,7 @@ def extract_precompiled_and_patch_package( with zipfile.ZipFile(wheel_path) as wheel: files_to_copy = [ "vllm/_C.abi3.so", + "vllm/_C_stable_libtorch.abi3.so", "vllm/_moe_C.abi3.so", "vllm/_flashmla_C.abi3.so", "vllm/_flashmla_extension_C.abi3.so", @@ -989,6 +990,8 @@ def _read_requirements(filename: str) -> list[str]: if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) + if _is_cuda() or _is_hip(): + ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch")) package_data = { "vllm": [ diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c2fcde4ab1cf..df469d9a1a96 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -14,6 +14,7 @@ # import custom ops, trigger op registration import vllm._C # noqa +import vllm._C_stable_libtorch # noqa from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml from vllm.utils.torch_utils import cuda_device_count_stateless diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a8a1d59f1bf0..16bbc0acbf45 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -37,6 +37,11 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) +try: + import vllm._C_stable_libtorch # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._C_stable_libtorch with %r", e) + # import custom ops, trigger op registration try: import vllm._rocm_C # noqa: F401