From 6b452e18df5d87975c168d33a84fdaa2faac9275 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 23 Feb 2026 18:00:47 +0000 Subject: [PATCH 1/2] [Build] Fix DSV3 kernels breaking _C and _moe_C on unsupported arches Signed-off-by: mgoin --- CMakeLists.txt | 1 - csrc/dsv3_fused_a_gemm.cu | 4 ++++ csrc/moe/dsv3_router_gemm_entry.cu | 5 +++++ csrc/moe/torch_bindings.cpp | 2 +- csrc/torch_bindings.cpp | 2 +- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a6f7f69468d1..55127a514f1f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -783,7 +783,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${DSV3_FUSED_A_GEMM_SRC}" CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) - list(APPEND VLLM_GPU_FLAGS "-DENABLE_DSV3_FUSED_A_GEMM=1") message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") else() message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " diff --git a/csrc/dsv3_fused_a_gemm.cu b/csrc/dsv3_fused_a_gemm.cu index 5b8374303ad0..65dff9c84bab 100644 --- a/csrc/dsv3_fused_a_gemm.cu +++ b/csrc/dsv3_fused_a_gemm.cu @@ -745,3 +745,7 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, stream); } } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm); +} diff --git a/csrc/moe/dsv3_router_gemm_entry.cu b/csrc/moe/dsv3_router_gemm_entry.cu index 1ba97bd76406..1599d520c1ed 100644 --- a/csrc/moe/dsv3_router_gemm_entry.cu +++ b/csrc/moe/dsv3_router_gemm_entry.cu @@ -24,6 +24,7 @@ #include #include +#include "core/registration.h" #include "dsv3_router_gemm_utils.h" static constexpr int DEFAULT_NUM_EXPERTS = 256; @@ -161,3 +162,7 @@ void dsv3_router_gemm(at::Tensor& output, // [num_tokens, num_experts] } } } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("dsv3_router_gemm", &dsv3_router_gemm); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 22b00f20ad57..438599451452 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -127,7 +127,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); + // conditionally compiled so impl registration is in source file #endif } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c16b9c223f62..39b6bc98a843 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -242,7 +242,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). ops.def( "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - ops.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm); + // conditionally compiled so impl registration is in source file // Quantized GEMM for AWQ. ops.def( From 68a40f7987a300530ad1f4e87f244da898b5fd99 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 23 Feb 2026 21:07:34 +0000 Subject: [PATCH 2/2] Fix build Signed-off-by: mgoin --- csrc/moe/dsv3_router_gemm_entry.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/moe/dsv3_router_gemm_entry.cu b/csrc/moe/dsv3_router_gemm_entry.cu index 1599d520c1ed..38fb681c2236 100644 --- a/csrc/moe/dsv3_router_gemm_entry.cu +++ b/csrc/moe/dsv3_router_gemm_entry.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include