-
-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[1/n] Migrate activation kernels to libtorch stable ABI #30908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -541,7 +541,25 @@ function (define_extension_target MOD_NAME) | |
| if (ARG_LANGUAGE STREQUAL "CUDA") | ||
| target_link_libraries(${MOD_NAME} PRIVATE torch CUDA::cudart CUDA::cuda_driver ${ARG_LIBRARIES}) | ||
| else() | ||
| target_link_libraries(${MOD_NAME} PRIVATE torch ${TORCH_LIBRARIES} ${ARG_LIBRARIES}) | ||
| # Link against PyTorch's bundled libtorch_hip.so (for DeviceGuard registration) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On a ROCm machine, I found that I needed these changes to make sure that vLLM was linking against these two .so files from torch (in particular the libamdhip64.so HIP runtime packaged by torch); otherwise the stable DeviceGuard would not work correctly. It seemed that two separate hipContexts were created: one by the raw HIP calls that vLLM made (which came from a HIP header from elsewhere), and one by the libtorch shims that called raw HIP APIs, which used the HIP headers packaged by torch.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mikaylagawarecki I'm not sure what's going on here. Do all extensions need this code? Or is vLLM doing something weird? |
||
| # and libamdhip64.so (to share a single HIP runtime with PyTorch). | ||
| find_library(TORCH_HIP_LIBRARY torch_hip PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) | ||
| find_library(TORCH_AMDHIP64_LIBRARY amdhip64 PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) | ||
|
|
||
| set(_hip_libs) | ||
| if (TORCH_HIP_LIBRARY) | ||
| list(APPEND _hip_libs ${TORCH_HIP_LIBRARY}) | ||
| endif() | ||
| if (TORCH_AMDHIP64_LIBRARY) | ||
| list(APPEND _hip_libs ${TORCH_AMDHIP64_LIBRARY}) | ||
| # Ensure PyTorch's bundled libamdhip64.so is found at runtime, not system ROCm's. | ||
| set(_torch_lib_dir "${TORCH_INSTALL_PREFIX}/lib") | ||
| set_target_properties(${MOD_NAME} PROPERTIES | ||
| BUILD_RPATH "${_torch_lib_dir}" | ||
| INSTALL_RPATH "${_torch_lib_dir}") | ||
| endif() | ||
|
|
||
| target_link_libraries(${MOD_NAME} PRIVATE torch ${_hip_libs} ${TORCH_LIBRARIES} ${ARG_LIBRARIES}) | ||
| endif() | ||
|
|
||
| install(TARGETS ${MOD_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} COMPONENT ${MOD_NAME}) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -165,8 +165,6 @@ void persistent_masked_m_silu_mul_quant( | |
| at::Tensor& y_s, // (E, T, H//group_size) [OUT] | ||
| bool use_ue8m0); | ||
|
|
||
| void mul_and_silu(torch::Tensor& out, torch::Tensor& input); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since it might be confusing why this is deleted: I kept the other declarations because the CPU torch_bindings.cpp includes ops.h. All the other ops are also defined for CPU, but this op isn't defined for CPU, so its declaration is deleted. |
||
|
|
||
| void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); | ||
|
|
||
| void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); | ||
|
|
||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| /* | ||
| * Stable ABI compatible dispatch utilities for vLLM. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Whenever you add "stable abi" in a comment you probably want to call it "libtorch stable abi" to clarify |
||
| * Adapted from dispatch_utils.h to use PyTorch's header-only (THO_*) macros | ||
| * instead of the ATen (AT_*) macros. | ||
| * | ||
| * These macros use: | ||
| * - THO_DISPATCH_SWITCH instead of AT_DISPATCH_SWITCH | ||
| * - THO_DISPATCH_CASE instead of AT_DISPATCH_CASE | ||
| * - torch::headeronly::ScalarType instead of at::ScalarType | ||
| * | ||
| * Add more macros here as needed when migrating additional kernels. | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <torch/headeronly/core/Dispatch.h> | ||
| #include <torch/headeronly/core/ScalarType.h> | ||
|
|
||
| #define VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(...) \ | ||
| THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \ | ||
| THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \ | ||
| THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__) | ||
|
|
||
| #define VLLM_STABLE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ | ||
| THO_DISPATCH_SWITCH(TYPE, NAME, \ | ||
| VLLM_STABLE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
#pragma once

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

// Declarations for activation kernels exposed through the libtorch stable
// ABI. All ops write their result into the pre-allocated `out` tensor
// (registered with a `Tensor!` mutable annotation in torch_bindings).

// Gated activation functions (input: [..., 2*d] -> output: [..., d])
void silu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void mul_and_silu(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_tanh_and_mul(torch::stable::Tensor& out,
                       torch::stable::Tensor& input);
// threshold: FATReLU cutoff applied to the gate half — see kernel for exact
// semantics.
void fatrelu_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
                     double threshold);
// alpha/limit: SwiGLU-OAI scaling and clamping parameters (schema defaults
// alpha=1.702, limit=7.0).
void swigluoai_and_mul(torch::stable::Tensor& out, torch::stable::Tensor& input,
                       double alpha, double limit);

// Element-wise activation functions (input: [..., d] -> output: [..., d])
void gelu_new(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_fast(torch::stable::Tensor& out, torch::stable::Tensor& input);
void gelu_quick(torch::stable::Tensor& out, torch::stable::Tensor& input);
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| #include "ops.h" | ||
| #include "core/registration.h" | ||
|
|
||
| #include <torch/csrc/stable/library.h> | ||
|
|
||
| // Register ops using STABLE_TORCH_LIBRARY for stable ABI compatibility. | ||
| // Note: We register under namespace "_C" so ops are accessible as | ||
| // torch.ops._C.<op_name> for compatibility with existing code. | ||
| STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) { | ||
| // Activation ops | ||
| // Activation function used in SwiGLU. | ||
| m.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to check... you are able to add tags to operators in the stable abi?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean adding |
||
|
|
||
| m.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); | ||
|
|
||
| // Activation function used in GeGLU with `none` approximation. | ||
| m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); | ||
|
|
||
| // Activation function used in GeGLU with `tanh` approximation. | ||
| m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); | ||
|
|
||
| // FATReLU implementation. | ||
| m.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); | ||
|
|
||
| m.def( | ||
| "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float " | ||
| "limit=7.0) -> ()"); | ||
|
|
||
| // GELU implementation used in GPT-2. | ||
| m.def("gelu_new(Tensor! out, Tensor input) -> ()"); | ||
|
|
||
| // Approximate GELU implementation. | ||
| m.def("gelu_fast(Tensor! out, Tensor input) -> ()"); | ||
|
|
||
| // Quick GELU implementation. | ||
| m.def("gelu_quick(Tensor! out, Tensor input) -> ()"); | ||
| } | ||
|
|
||
| STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { | ||
| // Gated activations | ||
| m.impl("silu_and_mul", TORCH_BOX(&silu_and_mul)); | ||
| m.impl("mul_and_silu", TORCH_BOX(&mul_and_silu)); | ||
| m.impl("gelu_and_mul", TORCH_BOX(&gelu_and_mul)); | ||
| m.impl("gelu_tanh_and_mul", TORCH_BOX(&gelu_tanh_and_mul)); | ||
| m.impl("fatrelu_and_mul", TORCH_BOX(&fatrelu_and_mul)); | ||
| m.impl("swigluoai_and_mul", TORCH_BOX(&swigluoai_and_mul)); | ||
|
|
||
| // Element-wise activations | ||
| m.impl("gelu_new", TORCH_BOX(&gelu_new)); | ||
| m.impl("gelu_fast", TORCH_BOX(&gelu_fast)); | ||
| m.impl("gelu_quick", TORCH_BOX(&gelu_quick)); | ||
| } | ||
|
|
||
| REGISTER_EXTENSION(_C_stable_libtorch) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| #pragma once | ||
|
|
||
| #include <torch/csrc/inductor/aoti_torch/c/shim.h> | ||
| #include <torch/headeronly/util/Exception.h> | ||
|
|
||
| #include <cuda_runtime.h> | ||
| #include <deque> | ||
| #include <mutex> | ||
| #include <string> | ||
| #include <vector> | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's no namespacing at all, is that intentional?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| // Device properties cache for stable ABI compatibility. | ||
| // Uses raw CUDA/HIP APIs instead of ATen functions. | ||
| // Thread-safe: each device's properties are queried exactly once. | ||
| inline std::deque<std::once_flag> device_prop_flags; | ||
| inline std::vector<cudaDeviceProp> device_prop_cache; | ||
| inline std::once_flag device_prop_vectors_init_flag; | ||
|
|
||
| inline void init_device_prop_vectors() { | ||
| int device_count; | ||
| cudaError_t err = cudaGetDeviceCount(&device_count); | ||
| if (err != cudaSuccess) { | ||
| STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " + | ||
| std::string(cudaGetErrorString(err))); | ||
| } | ||
| device_prop_flags.resize(device_count); | ||
| device_prop_cache.resize(device_count); | ||
| } | ||
|
|
||
| inline void init_device_prop(int device_index) { | ||
| cudaDeviceProp prop{}; | ||
| cudaError_t err = cudaGetDeviceProperties(&prop, device_index); | ||
| if (err != cudaSuccess) { | ||
| STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + | ||
| std::string(cudaGetErrorString(err))); | ||
| } | ||
| device_prop_cache[device_index] = prop; | ||
| } | ||
|
|
||
| inline cudaDeviceProp* get_device_prop() { | ||
| std::call_once(device_prop_vectors_init_flag, init_device_prop_vectors); | ||
| int device_index; | ||
| cudaError_t err = cudaGetDevice(&device_index); | ||
| if (err != cudaSuccess) { | ||
| STD_TORCH_CHECK( | ||
| false, "cudaGetDevice failed: " + std::string(cudaGetErrorString(err))); | ||
| } | ||
| std::call_once(device_prop_flags[device_index], init_device_prop, | ||
| device_index); | ||
| return &device_prop_cache[device_index]; | ||
| } | ||
|
|
||
| // Utility to get the current CUDA stream for a given device using stable APIs. | ||
| // Returns a cudaStream_t for use in kernel launches. | ||
| inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { | ||
| void* stream_ptr = nullptr; | ||
| TORCH_ERROR_CODE_CHECK( | ||
| aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); | ||
| return reinterpret_cast<cudaStream_t>(stream_ptr); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: comment that "_C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION which is currently set to 2.10". (explicitly state what we get)