Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS "${DSV3_FUSED_A_GEMM_SRC}"
CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC})
list(APPEND VLLM_GPU_FLAGS "-DENABLE_DSV3_FUSED_A_GEMM=1")
message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "
Expand Down
4 changes: 4 additions & 0 deletions csrc/dsv3_fused_a_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,7 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
stream);
}
}

// Register the CUDA backend implementation of dsv3_fused_a_gemm.
// The registration lives here (not in torch_bindings.cpp) because this
// source file is only compiled when a compatible CUDA arch is found
// (see the CMake "Not building dsv3_fused_a_gemm" branch), so the impl
// is registered exactly when the kernel actually exists.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("dsv3_fused_a_gemm", &dsv3_fused_a_gemm);
}
6 changes: 6 additions & 0 deletions csrc/moe/dsv3_router_gemm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include "core/registration.h"
Comment on lines 27 to +28
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The addition of TORCH_LIBRARY_IMPL_EXPAND at the end of this file requires TORCH_LIBRARY_IMPL, which is defined in <torch/library.h>. While core/registration.h is included, it does not provide the underlying PyTorch macro definitions. Including <torch/all.h> (as seen in csrc/dsv3_fused_a_gemm.cu) or <torch/library.h> is necessary to avoid compilation errors on supported architectures.

#include <cuda_runtime.h>
#include <torch/all.h>

#include "core/registration.h"

#include "dsv3_router_gemm_utils.h"

static constexpr int DEFAULT_NUM_EXPERTS = 256;
Expand Down Expand Up @@ -161,3 +163,7 @@ void dsv3_router_gemm(at::Tensor& output, // [num_tokens, num_experts]
}
}
}

// Register the CUDA backend implementation of dsv3_router_gemm.
// Placed in this conditionally compiled source file (rather than in
// moe/torch_bindings.cpp) so the impl is only registered when the
// kernel is actually built for a supported arch. Requires the
// TORCH_LIBRARY_IMPL machinery from <torch/library.h> / <torch/all.h>
// (included above) in addition to core/registration.h.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("dsv3_router_gemm", &dsv3_router_gemm);
}
2 changes: 1 addition & 1 deletion csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {

// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// This op is conditionally compiled, so its impl registration lives in the kernel's source file.
#endif
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
ops.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
// This op is conditionally compiled, so its impl registration lives in the kernel's source file.

// Quantized GEMM for AWQ.
ops.def(
Expand Down