Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -771,11 +771,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0, Hopper only).
# SM 10.0/11.0/12.0 (Blackwell) are excluded as this kernel relies on Hopper-specific features.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS)
set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu")
Expand All @@ -784,6 +785,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC})
list(APPEND VLLM_GPU_FLAGS "-DENABLE_DSV3_FUSED_A_GEMM=1")
add_compile_definitions(ENABLE_DSV3_FUSED_A_GEMM=1)
message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "
Expand Down
2 changes: 2 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ int64_t qr_max_size();
#endif

#ifndef USE_ROCM
#ifdef ENABLE_DSV3_FUSED_A_GEMM
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
torch::Tensor const& mat_b);
#endif
#endif
3 changes: 3 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// DeepSeek V3 fused A GEMM (SM 9.0 Hopper only, bf16 only, 1-16 tokens).
// Only available when the kernel is built for compatible architectures.
#ifdef ENABLE_DSV3_FUSED_A_GEMM
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
ops.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
#endif

// Quantized GEMM for AWQ.
ops.def(
Expand Down
45 changes: 31 additions & 14 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2789,22 +2789,39 @@ def sm100_cutlass_mla_get_workspace_size(
)


def dsv3_fused_a_gemm(
output: torch.Tensor,
mat_a: torch.Tensor,
mat_b: torch.Tensor,
) -> None:
"""DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
if hasattr(torch.ops._C, "dsv3_fused_a_gemm"):

Computes output = mat_a @ mat_b.T where:
mat_a: [num_tokens, 7168] row-major bf16 (hidden states)
mat_b: [7168, 2112] column-major bf16 (weight transposed)
output: [num_tokens, 2112] row-major bf16
def dsv3_fused_a_gemm(
output: torch.Tensor,
mat_a: torch.Tensor,
mat_b: torch.Tensor,
) -> None:
"""DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).

Optimized for the DeepSeek V2/V3 QKV A-projection at small batch sizes.
Requires SM 9.0+ (Hopper).
"""
torch.ops._C.dsv3_fused_a_gemm(output, mat_a, mat_b)
Computes output = mat_a @ mat_b.T where:
mat_a: [num_tokens, 7168] row-major bf16 (hidden states)
mat_b: [7168, 2112] column-major bf16 (weight transposed)
output: [num_tokens, 2112] row-major bf16

Optimized for the DeepSeek V2/V3 QKV A-projection at small batch sizes.
Requires SM 9.0 (Hopper only).
"""
torch.ops._C.dsv3_fused_a_gemm(output, mat_a, mat_b)

else:

def dsv3_fused_a_gemm(
output: torch.Tensor,
mat_a: torch.Tensor,
mat_b: torch.Tensor,
) -> None:
"""DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).

This kernel is only available for SM 9.0 (Hopper) architectures.
"""
raise RuntimeError(
"dsv3_fused_a_gemm is only supported on SM 9.0 (Hopper) architectures"
)


if hasattr(torch.ops._C, "weight_packed_linear"):
Expand Down
8 changes: 3 additions & 5 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,17 +730,15 @@ def __init__(
)

# Check if the DeepSeek V3 fused A GEMM kernel can be used.
# This kernel supports PDL and is optimized for low batch size.
# This kernel is only supported on Hopper (SM 9.0).
# Note: SM 10.0/11.0/12.0 (Blackwell) are not supported.
self._use_min_latency_gemm = (
hasattr(self, "weight")
and self.weight.dtype == torch.bfloat16
and self.weight.shape[0] == 2112
and self.weight.shape[1] == 7168
and current_platform.is_cuda()
and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability_family(100)
)
and current_platform.is_device_capability(90)
)

def forward(
Expand Down