Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
439f075
Enable _C_stable_libtorch for ROCm (HIP)
mikaylagawarecki Apr 1, 2026
b16f160
Move activation kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
131e5e0
[f/n] Migrate activation kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
60aeafa
Move INT8 quant kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
52a9bd6
[g/n] Migrate INT8 quant kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
3621698
Move FP8 quant kernel file from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
f4d32b2
[h/n] Migrate FP8 quant kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
4234544
Move GPTQ kernel files from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
502918e
[i/n] Migrate GPTQ kernels to torch stable ABI
mikaylagawarecki Apr 1, 2026
0592ba2
Move GGML/GGUF kernel files from csrc to csrc/libtorch_stable
mikaylagawarecki Apr 1, 2026
170c9d5
[j/n] Migrate GGML kernels to torch stable ABI
mikaylagawarecki Apr 2, 2026
47a8804
Missed converting torch::Tensor to orch::stable::Tensor in one locati…
cleonard530 May 14, 2026
e88fd9e
Add forward declarations in ops.h for silu_and_mul_clamp and silu_and…
cleonard530 May 14, 2026
fb33cd3
Move fused_silu_mul_block_quant from the legacy _C extension to the s…
cleonard530 May 15, 2026
796daf0
use TORCH_CHECK in dtype_fp8 when STD_TORCH_CHECK is undefined
cleonard530 May 15, 2026
71a471f
updated to use hip.hip_runtime.h header instead of cuda_runtime.h hea…
cleonard530 May 15, 2026
328a3ca
moved fused_silu_mul_block_quant.cu back to unstable libtorch for now…
cleonard530 May 18, 2026
f394c4d
Update torch_bindings.cpp
cleonard530 May 18, 2026
b7d1630
moved definition for TORCH_UTILS_CHECK to it's own header file csrc/t…
cleonard530 May 18, 2026
da74a3b
renamed csrc/torch_utils_check.h to the more general csrc/torch_utils.h
cleonard530 May 19, 2026
c511edf
Updated comments in torch_utils.h to make intention clearer.
cleonard530 May 19, 2026
a414194
forgot to change include header name from torch_utils_check.h to torc…
cleonard530 May 19, 2026
25caed6
fixed pre-commit linting errors
cleonard530 May 19, 2026
95c566f
Merge branch 'main' into new-stable-abi-phase6
Harry-Chen May 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,14 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu"
"csrc/quantization/w8a8/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge conflict around the "csrc/quantization/gguf/gguf_kernel.cu" line because "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" was newly added.

"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
Expand Down Expand Up @@ -628,33 +623,33 @@ define_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

# add OR VLLM_GPU_LANG STREQUAL "HIP" here once
# https://github.com/vllm-project/vllm/issues/35163 is resolved
if(VLLM_GPU_LANG STREQUAL "CUDA")
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
#
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
"csrc/libtorch_stable/quantization/gptq/q_gemm.cu"
"csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/libtorch_stable/permute_cols.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu"
"csrc/libtorch_stable/quantization/awq/gemm_kernels.cu")
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
endif()

# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
Expand Down Expand Up @@ -1034,6 +1029,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building hadacore")
endif()

# if CUDA endif
endif()

message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
Expand All @@ -1053,13 +1051,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)

# Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA)
# Needed to use cuda/hip APIs from C-shim
if(VLLM_GPU_LANG STREQUAL "CUDA")
target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
elseif(VLLM_GPU_LANG STREQUAL "HIP")
target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM)
endif()

# Needed by CUTLASS kernels
target_compile_definitions(_C_stable_libtorch PRIVATE
CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
# On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in
# get_device_prop()) which must resolve to the same libamdhip64.so that
# PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels),
# the raw HIP calls would otherwise resolve to the system ROCm copy,
# initializing a second HIP runtime that corrupts device state (wrong
# device on DeviceGuard, core dumps on multi-GPU tests).
#
# If PyTorch doesn't bundle libamdhip64 (built from source against system
# ROCm), there is only one copy in the process and no action is needed —
# the HIP compiler already links the system libamdhip64 automatically.
if(VLLM_GPU_LANG STREQUAL "HIP")
find_library(_STABLE_TORCH_AMDHIP64 amdhip64
PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH)
if(_STABLE_TORCH_AMDHIP64)
message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}")
target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64})
endif()
endif()
endif()

#
Expand Down
3 changes: 2 additions & 1 deletion csrc/attention/dtype_fp8.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "attention_generic.cuh"
#include "torch_utils.h"

#include <stdint.h>
#ifdef ENABLE_FP8
Expand Down Expand Up @@ -30,7 +31,7 @@ inline Fp8KVCacheDataType get_fp8_kv_cache_data_type(
} else if (dtype_str == "fp8_e5m2") {
return Fp8KVCacheDataType::kFp8E5M2;
}
TORCH_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
TORCH_UTILS_CHECK(false, "Unsupported fp8 kv cache data type: ", dtype_str);
}

// fp8 vector types for quantization of kv cache
Expand Down
2 changes: 2 additions & 0 deletions csrc/cuda_vec_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
Expand Down
6 changes: 2 additions & 4 deletions csrc/cutlass_extensions/torch_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "torch_utils.h"

// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
Expand All @@ -8,13 +10,9 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif

#include "cute/layout.hpp"
Expand Down
Loading
Loading