[ModelBash][DSV3] Add TRTLLM DSV3 Router GEMM kernel (6% B1 Speedup) #34302
robertgshaw2-redhat merged 25 commits into main
Conversation
Port the optimized router GEMM kernel from sglang's sgl-kernel for
DeepSeek V3 MoE models. This kernel is specifically optimized for
small batch sizes (1-16 tokens) common in decode phase.
Key features:
- Computes output = mat_a @ mat_b.T for MoE routing
- Supports bfloat16 input with float32 or bfloat16 output
- Optimized for DSV3 dimensions: hidden_dim=7168, num_experts={256,384}
- Requires SM90+ (Hopper) GPUs and CUDA 12.0+
- Supports Programmatic Dependent Launch (PDL) via TRTLLM_ENABLE_PDL=1
Original kernel adapted from TensorRT-LLM's dsv3RouterGemm implementation.
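As a sanity reference for what the kernel computes (output = mat_a @ mat_b.T over token activations and column-major router weights), here is a minimal pure-Python sketch at toy sizes; the shapes below are illustrative, not the real DSV3 dimensions:

```python
def router_gemm_ref(mat_a, mat_b):
    """Reference for the router GEMM: out[m][n] = sum_k a[m][k] * b[n][k].

    mat_a: num_tokens x hidden_dim activations
    mat_b: num_experts x hidden_dim router weights (so out = A @ B^T)
    """
    num_tokens = len(mat_a)
    hidden_dim = len(mat_a[0])
    num_experts = len(mat_b)
    out = [[0.0] * num_experts for _ in range(num_tokens)]
    for m in range(num_tokens):
        for n in range(num_experts):
            out[m][n] = sum(mat_a[m][k] * mat_b[n][k] for k in range(hidden_dim))
    return out

# Toy example: 2 tokens, hidden_dim 3, 2 experts.
a = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
b = [[1.0, 0.0, 1.0], [2.0, 2.0, 2.0]]
logits = router_gemm_ref(a, b)  # routing logits, one row per token
```

The CUDA kernel produces the same result, but fused with vectorized bf16 loads and a per-block reduction over the hidden dimension.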
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Code Review
This pull request ports an optimized router GEMM kernel for DeepSeek V3 MoE models from sglang. The changes include the CUDA kernel implementation, build system integration, and PyTorch bindings. The kernel is highly specialized for specific model configurations and hardware (SM90+). My review focuses on the new CUDA kernel implementation. I've identified significant code duplication between the float32 and bfloat16 output kernels, which should be refactored for better maintainability. Additionally, there are missing error checks for CUDA API calls, which could lead to unhandled runtime errors.
csrc/moe/dsv3_router_gemm.cu
```cpp
inline int getSMVersion() {
  int device{-1};
  cudaGetDevice(&device);
  int sm_major = 0;
  int sm_minor = 0;
  cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
  return sm_major * 10 + sm_minor;
}
```
The CUDA API calls cudaGetDevice and cudaDeviceGetAttribute can return errors, but their return values are not being checked. This could lead to silent failures or undefined behavior if an error occurs (e.g., no CUDA device is available). It's important to handle these potential errors by checking the cudaError_t return value.
For example:

```cpp
cudaError_t err = cudaGetDevice(&device);
if (err != cudaSuccess) {
  // Handle error
}
```

Or use a macro for the check, which is a common practice in CUDA projects to reduce boilerplate. Other parts of the vLLM codebase use error checking macros for CUDA calls, and that practice should be followed here for consistency and robustness.
csrc/moe/dsv3_router_gemm.cu
```cpp
// Router GEMM kernel with float32 output
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(
    float* out, T const* mat_a, T const* mat_b) {
  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;

  // Initialize accumulators for all M rows
  float acc[kNumTokens] = {};

  // Shared memory for warp-level reduction
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // B matrix is in column-major order
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load B matrix values using vector load
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

    // Convert B values to float
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);

      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Warp-level reduction
  int const warpId = tid / 32;
  int const laneId = tid % 32;

  float warp_result[kNumTokens];
#pragma unroll
  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
    warp_result[m_idx] = acc[m_idx];
  }

#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = warp_result[m];
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);

    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }

  __syncthreads();

  // Final reduction across warps
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      out[m * kNumExperts + n_idx] = final_sum;
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Router GEMM kernel with bfloat16 output
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(
    __nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;

  float acc[kNumTokens] = {};
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  T const* b_col = mat_b + n_idx * kHiddenDim;

  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);

      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  int const warpId = tid / 32;
  int const laneId = tid % 32;

  float warp_result[kNumTokens];
#pragma unroll
  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
    warp_result[m_idx] = acc[m_idx];
  }

#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = warp_result[m];
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);

    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }

  __syncthreads();

  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
```
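The warp-level reduction in both kernels is an XOR butterfly: each __shfl_xor_sync step with offsets 16, 8, 4, 2, 1 pairs lane i with lane i ^ offset, so after five steps every lane of the 32-lane warp holds the full sum. A hypothetical pure-Python simulation of that pattern (just to illustrate the data movement, not the hardware intrinsic):

```python
WARP_SIZE = 32

def shfl_xor(vals, mask):
    """Simulate __shfl_xor_sync: lane i reads the value held by lane i ^ mask."""
    return [vals[i ^ mask] for i in range(WARP_SIZE)]

def warp_reduce_sum(vals):
    """XOR-butterfly reduction: after offsets 16, 8, 4, 2, 1,
    every lane holds the sum over all 32 lanes."""
    for mask in (16, 8, 4, 2, 1):
        partner = shfl_xor(vals, mask)
        vals = [v + p for v, p in zip(vals, partner)]
    return vals

lanes = [float(i) for i in range(WARP_SIZE)]  # lane i contributes i
reduced = warp_reduce_sum(lanes)              # sum of 0..31 on every lane
```

In the kernel, only lane 0 of each warp then writes its copy of the sum into shared memory for the final cross-warp reduction.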
The kernels router_gemm_kernel_float_output and router_gemm_kernel_bf16_output are nearly identical, with the only difference being the output data type and the final store operation. This large amount of duplicated code increases maintenance overhead and the risk of introducing inconsistencies.
To improve maintainability, these two kernels should be refactored into a single templated kernel. You can introduce a helper struct OutputWriter templated on the output type to handle the final store operation.
Here's a sketch of the proposed refactoring:

```cpp
template <typename T_out>
struct OutputWriter;

template <>
struct OutputWriter<float> {
  __device__ __forceinline__ static void write(float* out, int index,
                                               float value) {
    out[index] = value;
  }
};

template <>
struct OutputWriter<__nv_bfloat16> {
  __device__ __forceinline__ static void write(__nv_bfloat16* out, int index,
                                               float value) {
    out[index] = __float2bfloat16(value);
  }
};

template <typename T, typename T_out, int kBlockSize, int VPT, int kNumTokens,
          int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel(
    T_out* out, T const* mat_a, T const* mat_b) {
  // ... common kernel logic ...

  // In the final reduction section
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      OutputWriter<T_out>::write(out, m * kNumExperts + n_idx, final_sum);
    }
  }

  // ... rest of common kernel logic ...
}
```

Then invokeRouterGemmFloatOutput and invokeRouterGemmBf16Output can call this unified router_gemm_kernel with the appropriate output type (float or __nv_bfloat16). This will eliminate about 100 lines of redundant code.
Can we use the router gemm interface already present in flashinfer?
Signed-off-by: Robert Shaw <robshaw@redhat.com>
cc @pavanimajety - looks like these don't support SM90. Any idea of the plan here?
TODO:
CMakeLists.txt
```cmake
endif()

# DeepSeek V3 router GEMM kernel - requires SM90+
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
```
This isn't compatible with CUDA 13 and is missing Blackwell Ultra; it should be something like:

```cmake
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
  cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
  cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
```
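The reviewer's version split amounts to choosing a supported-arch list per toolkit version and intersecting it with the build's requested arches. A hypothetical Python sketch of that logic; the function name is illustrative, and a plain set intersection only approximates what vLLM's cuda_archs_loose_intersection CMake helper does:

```python
def dsv3_router_gemm_archs(cuda_archs, compiler_version):
    """Pick the arch list for the kernel based on the CUDA compiler version.

    CUDA >= 13 can use family-wide codes (10.0f/11.0f), while older
    toolkits must enumerate the arch-specific Blackwell variants.
    """
    major = int(compiler_version.split(".")[0])
    if major >= 13:
        supported = {"9.0a", "10.0f", "11.0f"}
    else:
        supported = {"9.0a", "10.0a", "10.1a", "10.3a"}
    # "Loose intersection" is approximated here as a plain set intersection.
    return sorted(supported & set(cuda_archs))

# CUDA 12.x build requesting Hopper, Blackwell, and an unsupported arch:
archs = dsv3_router_gemm_archs(["9.0a", "10.0a", "12.0a"], "12.8")
```

Note how an sm12x arch in the request simply drops out of the intersection, which is exactly why builds targeting only 12.0 would not compile these sources.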
```python
def _set_allow_dsv3_router_gemm(self) -> None:
    self.allow_dsv3_router_gemm = (
        current_platform.is_cuda()
        and current_platform.has_device_capability((9, 0))
```
This should be current_platform.is_device_capability(90) or current_platform.is_device_capability_family(100), since we aren't supporting sm120.
It also looks like you need to check against supported n_experts since I only see instantiations for 256 or 384 experts
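Putting both review points together, the gate would need to check arch and shape. A hypothetical sketch of such a guard; the function and constant names are illustrative, not vLLM's actual API, and the family check is simplified to the two arches discussed:

```python
SUPPORTED_NUM_EXPERTS = {256, 384}  # only these instantiations exist in the kernel
SUPPORTED_HIDDEN_DIM = 7168         # DSV3 hidden size
MAX_TOKENS = 16                     # kernel targets small decode batches

def can_use_dsv3_router_gemm(num_tokens, hidden_dim, num_experts, sm_version):
    """Return True only for shapes and arches the kernel is compiled for."""
    return (
        sm_version in (90, 100)     # SM90 / SM100 family only, per the review
        and 1 <= num_tokens <= MAX_TOKENS
        and hidden_dim == SUPPORTED_HIDDEN_DIM
        and num_experts in SUPPORTED_NUM_EXPERTS
    )

ok = can_use_dsv3_router_gemm(1, 7168, 256, 90)    # DSV3 decode shape
bad = can_use_dsv3_router_gemm(1, 7168, 128, 90)   # unsupported expert count
```

Anything failing the guard would fall back to the generic routing GEMM path.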
csrc/moe/dsv3_router_gemm_entry.cu
```cpp
      "output must be float32 or bf16");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
```
Do you know if this would work on SM120? Better to be explicit if we don't know
Hey all, FYI - there's a flashinfer PR ready that removes the restriction for non-SM100, in case we want to switch to the flashinfer implementation: flashinfer-ai/flashinfer#2576
I think we still expect these kernels to improve and evolve (new HW arch) in FI, so it would be great to consider invoking them directly with Flashinfer (perhaps with the 0.6.5 update). Not blocking for this PR though; I'll keep track.
@xinli-sw - sounds good. We can add it once flashinfer hits 0.6.5.
This commit somehow breaks model loading on Spark; attaching the log.
@robertgshaw2-redhat - this is the second DSV3-related PR that breaks vLLM on DGX Spark (and other sm12x). I believe you need to guard it properly. @mgoin, @johnnynunez - FYI. EDIT: the first PR was this one: #34758
Thank you for reporting @eugr @stavinsky, and sorry for the disruption. I should have a fix here: #35123
Thanks, sorry for the issues.
Always happy to help, guys.
No problem, it's a big project with very wide hardware support. Stuff happens.
I have to say, I don't quite get why this did not break on SM89, where we run a lot of tests.
@robertgshaw2-redhat It is because the CI image is built with a wide-ranging TORCH_CUDA_ARCH_LIST, basically including all source files and cases across CUDA arches. You would only run into this issue if you build for just your arch, i.e. TORCH_CUDA_ARCH_LIST=12.0, since you wouldn't build those source files.
…llm-project#34302) Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
…llm-project#34302) Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Andrii Skliar <askliar@nvidia.com>
…llm-project#34302) Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: EricccYang <yangyang4991@gmail.com>
Port the optimized router GEMM kernel from sglang's sgl-kernel for DeepSeek V3 MoE models. This kernel is specifically optimized for small batch sizes (1-16 tokens) common in decode phase. The kernel is originally adapted from TRTLLM.
Key features:
- Computes output = mat_a @ mat_b.T for MoE routing
- Supports bfloat16 input with float32 or bfloat16 output
- Optimized for DSV3 dimensions: hidden_dim=7168, num_experts={256,384}
- Requires SM90+ (Hopper) GPUs and CUDA 12.0+
- Supports Programmatic Dependent Launch (PDL) via TRTLLM_ENABLE_PDL=1

Original kernel adapted from TensorRT-LLM's dsv3RouterGemm implementation.
5.5% E2E Speedup for Batch 1 Decode.
Purpose
Test Plan
eval (run with concurrency 10 to hit the low batch size):

```shell
lm_eval \
  --model local-completions \
  --tasks gsm8k \
  --model_args "model={{MODEL}},base_url=http://localhost:{{PORT}}/v1/completions,num_concurrent=10,tokenized_requests=False" \
  --limit 100
```

benchmark:

```shell
vllm bench serve \
  --port {{PORT}} \
  --model {{MODEL}} \
  --dataset-name random \
  --input-len 2 \
  --output-len 100 \
  --max-concurrency 1 \
  --num-prompts 10 \
  --seed $(date +%s) \
  --temperature 0.0
```

Test Result
local-completions ({'model': 'nvidia/DeepSeek-V3.1-NVFP4', 'base_url': 'http://localhost:8001/v1/completions', 'num_concurrent': 10, 'tokenized_requests': False}), gen_kwargs: ({}), limit: 100.0, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter           | n-shot | Metric      |   | Value |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ |  0.97 | ± | 0.0171 |
|       |         | strict-match     |      5 | exact_match | ↑ |  0.97 | ± | 0.0171 |

Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.