-
-
Notifications
You must be signed in to change notification settings - Fork 18k
[6/n] Migrate activation kernels, gptq, gguf, non cutlass w8a8 to libtorch stable ABI #38757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bfe9697
e3606eb
ec7ca63
364e676
f0b7eee
3500eae
5e1b090
ec7793c
a9b1ed2
2c13410
c41a8f8
f64dd26
690ee02
66206be
d0cf841
2fa10a1
7445da6
4e9bd85
149b9a6
aff8a2a
deea661
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -288,18 +288,13 @@ set(VLLM_EXT_SRC | |
| "csrc/attention/merge_attn_states.cu" | ||
| "csrc/attention/vertical_slash_index.cu" | ||
| "csrc/pos_encoding_kernels.cu" | ||
| "csrc/activation_kernels.cu" | ||
| "csrc/layernorm_kernels.cu" | ||
| "csrc/fused_qknorm_rope_kernel.cu" | ||
| "csrc/layernorm_quant_kernels.cu" | ||
| "csrc/sampler.cu" | ||
| "csrc/topk.cu" | ||
| "csrc/cuda_view.cu" | ||
| "csrc/quantization/gptq/q_gemm.cu" | ||
| "csrc/quantization/w8a8/int8/scaled_quant.cu" | ||
| "csrc/quantization/w8a8/fp8/common.cu" | ||
| "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" | ||
| "csrc/quantization/gguf/gguf_kernel.cu" | ||
| "csrc/quantization/activation_kernels.cu" | ||
| "csrc/cuda_utils_kernels.cu" | ||
| "csrc/custom_all_reduce.cu" | ||
|
|
@@ -339,7 +334,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |
| FetchContent_MakeAvailable(cutlass) | ||
|
|
||
| list(APPEND VLLM_EXT_SRC | ||
| "csrc/quantization/awq/gemm_kernels.cu" | ||
| "csrc/cutlass_extensions/common.cpp" | ||
| "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu") | ||
|
|
||
|
|
@@ -472,46 +466,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |
| " in CUDA target architectures") | ||
| endif() | ||
|
|
||
| # Only build AllSpark kernels if we are building for at least some compatible archs. | ||
| cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") | ||
| if (ALLSPARK_ARCHS) | ||
| set(ALLSPARK_SRCS | ||
| "csrc/quantization/gptq_allspark/allspark_repack.cu" | ||
| "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${ALLSPARK_SRCS}" | ||
| CUDA_ARCHS "${ALLSPARK_ARCHS}") | ||
| list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") | ||
| message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") | ||
| else() | ||
| message(STATUS "Not building AllSpark kernels as no compatible archs found" | ||
| " in CUDA target architectures") | ||
| endif() | ||
|
|
||
| # CUTLASS MLA Archs and flags | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) | ||
| cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") | ||
| else() | ||
| cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") | ||
| endif() | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) | ||
| set(SRCS | ||
| "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${SRCS}" | ||
| CUDA_ARCHS "${MLA_ARCHS}") | ||
| list(APPEND VLLM_EXT_SRC "${SRCS}") | ||
| list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") | ||
| # Add MLA-specific include directories only to MLA source files | ||
| set_source_files_properties(${SRCS} | ||
| PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") | ||
| message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") | ||
| else() | ||
| message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") | ||
| # clear MLA_ARCHS | ||
| set(MLA_ARCHS) | ||
| endif() | ||
|
|
||
| # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) | ||
| cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") | ||
|
|
@@ -539,24 +493,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |
| endif() | ||
| endif() | ||
|
|
||
| # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) | ||
| cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") | ||
| else() | ||
| cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") | ||
| endif() | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) | ||
| set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${DSV3_FUSED_A_GEMM_SRC}" | ||
| CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") | ||
| list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) | ||
| message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") | ||
| else() | ||
| message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " | ||
| "in CUDA target architectures.") | ||
| endif() | ||
|
|
||
| # | ||
| # Machete kernels | ||
|
|
||
|
|
@@ -628,16 +564,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |
| endif() | ||
|
|
||
|
|
||
| # Hadacore kernels | ||
| cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") | ||
| if(HADACORE_ARCHS) | ||
| set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${SRCS}" | ||
| CUDA_ARCHS "${HADACORE_ARCHS}") | ||
| list(APPEND VLLM_EXT_SRC "${SRCS}") | ||
| message(STATUS "Building hadacore") | ||
| endif() | ||
|
|
||
| # if CUDA endif | ||
| endif() | ||
|
|
@@ -669,31 +595,66 @@ define_extension_target( | |
| # Setting this variable sidesteps the issue by calling the driver directly. | ||
| target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) | ||
|
|
||
| # add OR VLLM_GPU_LANG STREQUAL "HIP" here once | ||
| # https://github.com/vllm-project/vllm/issues/35163 is resolved | ||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||
| if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") | ||
| # | ||
| # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) | ||
| # | ||
| set(VLLM_STABLE_EXT_SRC | ||
| "csrc/libtorch_stable/torch_bindings.cpp" | ||
| "csrc/cutlass_extensions/common.cpp" | ||
| "csrc/cuda_utils_kernels.cu" | ||
| "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" | ||
| "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" | ||
| "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu") | ||
| "csrc/libtorch_stable/activation_kernels.cu" | ||
| "csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu" | ||
| "csrc/libtorch_stable/quantization/w8a8/fp8/common.cu" | ||
| "csrc/libtorch_stable/quantization/gptq/q_gemm.cu" | ||
| "csrc/libtorch_stable/quantization/gguf/gguf_kernel.cu") | ||
|
|
||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||
| list(APPEND VLLM_STABLE_EXT_SRC | ||
| "csrc/cuda_utils_kernels.cu" | ||
| "csrc/cutlass_extensions/common.cpp" | ||
| "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" | ||
| "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" | ||
| "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu" | ||
| "csrc/libtorch_stable/permute_cols.cu" | ||
| "csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu" | ||
| "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu") | ||
| endif() | ||
| "csrc/libtorch_stable/quantization/w8a8/int8/per_token_group_quant.cu" | ||
| "csrc/libtorch_stable/quantization/awq/gemm_kernels.cu") | ||
|
|
||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${VLLM_STABLE_EXT_SRC}" | ||
| CUDA_ARCHS "${CUDA_ARCHS}") | ||
|
|
||
| # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) | ||
| cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") | ||
| else() | ||
| cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") | ||
| endif() | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS) | ||
| set(SRCS "csrc/libtorch_stable/dsv3_fused_a_gemm.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${SRCS}" | ||
| CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") | ||
| list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") | ||
| message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") | ||
| else() | ||
| message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " | ||
| "in CUDA target architectures.") | ||
| endif() | ||
|
|
||
| # Only build AllSpark kernels if we are building for at least some compatible archs. | ||
| cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") | ||
| if (ALLSPARK_ARCHS) | ||
| set(SRCS | ||
| "csrc/libtorch_stable/quantization/gptq_allspark/allspark_repack.cu" | ||
| "csrc/libtorch_stable/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${SRCS}" | ||
| CUDA_ARCHS "${ALLSPARK_ARCHS}") | ||
| list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") | ||
| message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") | ||
| else() | ||
| message(STATUS "Not building AllSpark kernels as no compatible archs found" | ||
| " in CUDA target architectures") | ||
| endif() | ||
|
|
||
| # | ||
|
|
@@ -989,6 +950,44 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |
| endif() | ||
| endif() | ||
|
|
||
| # CUTLASS MLA Archs and flags | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) | ||
| cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") | ||
| else() | ||
| cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") | ||
| endif() | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) | ||
| set(SRCS | ||
| "csrc/libtorch_stable/attention/mla/sm100_cutlass_mla_kernel.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${SRCS}" | ||
| CUDA_ARCHS "${MLA_ARCHS}") | ||
| list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") | ||
| list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") | ||
| # Add MLA-specific include directories only to MLA source files | ||
| set_source_files_properties(${SRCS} | ||
| PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") | ||
| message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") | ||
| else() | ||
| message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") | ||
| # clear MLA_ARCHS | ||
| set(MLA_ARCHS) | ||
| endif() | ||
|
|
||
| # Hadacore kernels | ||
| cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") | ||
| if(HADACORE_ARCHS) | ||
| set(SRCS "csrc/libtorch_stable/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") | ||
| set_gencode_flags_for_srcs( | ||
| SRCS "${SRCS}" | ||
| CUDA_ARCHS "${HADACORE_ARCHS}") | ||
| list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") | ||
| message(STATUS "Building hadacore") | ||
| endif() | ||
|
|
||
| # if CUDA endif | ||
| endif() | ||
|
|
||
| message(STATUS "Enabling C_stable extension.") | ||
| define_extension_target( | ||
| _C_stable_libtorch | ||
|
|
@@ -1008,13 +1007,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |
| target_compile_definitions(_C_stable_libtorch PRIVATE | ||
| TORCH_TARGET_VERSION=0x020A000000000000ULL) | ||
|
|
||
| # Needed to use cuda APIs from C-shim | ||
| target_compile_definitions(_C_stable_libtorch PRIVATE | ||
| USE_CUDA) | ||
| # Needed to use cuda/hip APIs from C-shim | ||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||
| target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA) | ||
| # Needed by CUTLASS kernels | ||
| target_compile_definitions(_C_stable_libtorch PRIVATE | ||
| CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) | ||
| elseif(VLLM_GPU_LANG STREQUAL "HIP") | ||
| target_compile_definitions(_C_stable_libtorch PRIVATE USE_ROCM) | ||
| endif() | ||
|
|
||
| # Needed by CUTLASS kernels | ||
| target_compile_definitions(_C_stable_libtorch PRIVATE | ||
| CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) | ||
| # On ROCm, _C_stable_libtorch calls raw HIP APIs (e.g. hipGetDevice in | ||
| # get_device_prop()) which must resolve to the same libamdhip64.so that | ||
| # PyTorch uses. When PyTorch bundles its own copy (pip/conda wheels), | ||
| # the raw HIP calls would otherwise resolve to the system ROCm copy, | ||
| # initializing a second HIP runtime that corrupts device state (wrong | ||
| # device on DeviceGuard, core dumps on multi-GPU tests). | ||
| # | ||
| # If PyTorch doesn't bundle libamdhip64 (built from source against system | ||
| # ROCm), there is only one copy in the process and no action is needed — | ||
| # the HIP compiler already links the system libamdhip64 automatically. | ||
| if(VLLM_GPU_LANG STREQUAL "HIP") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to improve my understanding, this code basically specifically picks out the amdhip64 that pytorch bundles in order to have deterministic correct results and not get corrupted?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yea concretely there seem to be two cases
|
||
| find_library(_STABLE_TORCH_AMDHIP64 amdhip64 | ||
| PATHS "${TORCH_INSTALL_PREFIX}/lib" NO_DEFAULT_PATH) | ||
| if(_STABLE_TORCH_AMDHIP64) | ||
| message(STATUS "Found PyTorch-bundled libamdhip64 at ${_STABLE_TORCH_AMDHIP64}") | ||
| target_link_libraries(_C_stable_libtorch PRIVATE ${_STABLE_TORCH_AMDHIP64}) | ||
| endif() | ||
| endif() | ||
| endif() | ||
|
|
||
| # | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,8 @@ | |
|
|
||
| #ifdef USE_ROCM | ||
| #include <hip/hip_runtime.h> | ||
| #include <hip/hip_bf16.h> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need these?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Concrete error without these is Ah concretely it seems like there is a bug here https://github.com/pytorch/pytorch/blob/main/torch/headeronly/util/BFloat16.h#L15-L17, the We defined USE_ROCM for _C_stable_libtorch to expose some of the shims that are gated :/ |
||
| #include <hip/hip_fp16.h> | ||
| #else | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how come this common.cpp needs to move to if CUDA?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edit: hmm wait just to confirm you were referring to csrc/cuda_utils_kernels.cu not common.cpp (which is correctly CUDA-only) right?
technically csrc/cuda_utils_kernels.cu should be shared cuda/rocm, but there's some issues with it being in sources for both extensions when building on rocm, so I want to punt that problem which will be solved when we fully migrate it out of _C
The error looks like