Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c30d36b
[MoE] Add DeepSeek V3 router GEMM kernel from sglang
Feb 11, 2026
48bae28
match sgl structure
Feb 11, 2026
ef4549f
add the missing files
Feb 11, 2026
b863fe5
make dsv3 out work nicely
Feb 11, 2026
6385e0e
creating it
Feb 11, 2026
744bcc6
compile for sm100
Feb 11, 2026
dbde9d6
trying to get compilation working
Feb 11, 2026
d717911
torch --> at
Feb 11, 2026
0c0478f
working end to end
Feb 11, 2026
dd0a2b0
only use fp32 for trtllm nvfp4
Feb 11, 2026
ca866f4
remove nit
Feb 11, 2026
205ef51
cleanup
Feb 11, 2026
2623752
add comment about why we need a setter for _out_type
Feb 11, 2026
c9a1cad
cmake
Feb 11, 2026
0320b52
Merge remote-tracking branch 'origin/main' into use-sgl-gate-for-fp32…
Feb 17, 2026
945119d
Use ATen's getCurrentDeviceProperties() for SM version detection
Feb 17, 2026
1a82cd0
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 18, 2026
e7781de
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 18, 2026
931fef9
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 19, 2026
18ea1e8
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
mgoin Feb 19, 2026
1fbac3d
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 20, 2026
5f87204
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 20, 2026
0cca4f3
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 20, 2026
e2deb0c
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 22, 2026
880b6f0
Merge branch 'main' into use-sgl-gate-for-fp32-router-logits
robertgshaw2-redhat Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
" in CUDA target architectures")
endif()

# DeepSeek V3 router GEMM kernel - requires SM90+
# CUDA 13+ accepts family-conditional arch specifiers (the "f" suffix, e.g.
# 10.0f), so one entry covers a whole SM family; older toolkits must list
# each arch-conditional ("a" suffix) target explicitly.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
  cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
  cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS)
  # Entry point plus one translation unit per output dtype (float / bf16).
  set(DSV3_ROUTER_GEMM_SRC
      "csrc/moe/dsv3_router_gemm_entry.cu"
      "csrc/moe/dsv3_router_gemm_float_out.cu"
      "csrc/moe/dsv3_router_gemm_bf16_out.cu")
  set_gencode_flags_for_srcs(
      SRCS "${DSV3_ROUTER_GEMM_SRC}"
      CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
  list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}")
  message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
else()
  message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found"
                 " (requires SM90+ and CUDA >= 12.0)")
endif()
endif()

message(STATUS "Enabling moe extension.")
Expand Down
291 changes: 291 additions & 0 deletions csrc/moe/dsv3_router_gemm_bf16_out.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include "dsv3_router_gemm_utils.h"

// Fused multiply-add on a packed pair of floats, d = a * b + c, issued as a
// single PTX instruction. fma.rn.f32x2 is a packed-fp32 instruction that
// requires SM90+ (consistent with this file's griddepcontrol guards).
// NOTE(review): not referenced by the kernel in this translation unit;
// presumably kept for parity with the upstream TRT-LLM/sglang source —
// confirm before removing.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b,
                                    float2 const& c) {
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
               : "=l"(reinterpret_cast<uint64_t&>(d))
               : "l"(reinterpret_cast<uint64_t const&>(a)),
                 "l"(reinterpret_cast<uint64_t const&>(b)),
                 "l"(reinterpret_cast<uint64_t const&>(c)));
}

// Unpack the VPT bfloat16 values packed inside a 16-byte uint4 into a float
// array. Caller guarantees `dst` has room for VPT floats; for bf16 inputs
// VPT == 8 (16 bytes / 2 bytes per element).
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec,
                                                     float* dst) {
  // Reinterpret the raw 128-bit payload as bf16 lanes; no const_cast needed
  // since we only read from it.
  __nv_bfloat16 const* lanes = reinterpret_cast<__nv_bfloat16 const*>(&vec);
#pragma unroll
  for (int i = 0; i < VPT; ++i) {
    dst[i] = __bfloat162float(lanes[i]);
  }
}

// Router GEMM producing bf16 logits: out[m][n] = sum_k A[m][k] * B[k][n].
// Launch layout (see invokeRouterGemmBf16Output): gridDim.x == kNumExperts
// (one block per output column / expert), blockDim.x == kBlockSize == 128.
// Preconditions:
//   - kHiddenDim must be divisible by VPT * kBlockSize; the integer division
//     computing k_iterations below silently drops any remainder.
//   - mat_a and mat_b must be 16-byte aligned (uint4 vector loads).
//   - out is row-major [kNumTokens, kNumExperts]; mat_b is column-major
//     (column n is kHiddenDim contiguous elements).
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(
    __nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  // Each block handles one expert column
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Constants for this kernel
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations =
      kHiddenDim / k_elems_per_k_iteration;  // Total K iterations

  // Initialize accumulators for all M rows (fp32 accumulation of bf16 inputs)
  float acc[kNumTokens] = {};

  // Shared memory staging for the cross-warp reduction: one partial sum per
  // (token, warp) pair.
  __shared__ float sm_reduction[kNumTokens][kNumWarps];  // kNumWarps

  // B matrix is in column-major order, so we can directly load a column for the
  // n_idx expert
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

  // PDL fence: block until the dependent-launch predecessor signals; a no-op
  // unless the launcher enabled programmatic stream serialization.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  // Process the GEMM in chunks: each thread handles VPT contiguous elements
  // of K per iteration, strided across the block.
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // Load B matrix values using vector load (8 bf16 values)
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);

    // Convert B values to float
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

    // Process each token
#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Load both rows of A matrix using vector loads
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);

      // Convert A values to float
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

      // Process elements in this chunk
#pragma unroll
      for (int k = 0; k < VPT; k++) {
        float a = a_float[k];
        float b = b_float[k];
        acc[m_idx] += a * b;
      }
    }
  }

  // Perform warp-level reduction.
  // NOTE(review): this local `warpSize` shadows CUDA's built-in warpSize;
  // harmless (both are 32), but kWarpSize above could be reused instead.
  int const warpSize = 32;
  int const warpId = tid / warpSize;
  int const laneId = tid % warpSize;

  // Register for warp-level reduction results
  float warp_result[kNumTokens];

#pragma unroll
  for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
    warp_result[m_idx] = acc[m_idx];
  }

  // Perform warp-level reduction using optimized butterfly pattern.
  // The full 0xffffffff mask is valid: kBlockSize is a multiple of 32, so
  // every lane of every warp reaches this point.
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = warp_result[m];

    // Butterfly reduction pattern
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);

    // Only the first thread in each warp stores to shared memory
    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }

  // Barrier: all warps must publish their partials before thread 0 reads them.
  __syncthreads();

  // Final reduction across warps (only first thread)
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;

      // Sum across the kNumWarps
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }

      // Write final result (fp32 accumulator rounded to bf16)
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }
  // PDL fence: allow dependent grids to start once this kernel's output is
  // committed.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

// Host-side launcher for router_gemm_kernel_bf16_output.
// Grid: one block per expert column (kNumExperts); block: 128 threads.
// Programmatic dependent launch (PDL) is enabled when getEnvEnablePDL()
// returns true, pairing with the griddepcontrol fences inside the kernel.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a,
                                T const* mat_b, cudaStream_t stream) {
  // 16-byte vector loads => 8 elements per thread per load for bf16 inputs.
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;

  cudaLaunchAttribute launch_attrs[1];
  launch_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  launch_attrs[0].val.programmaticStreamSerializationAllowed =
      getEnvEnablePDL();

  cudaLaunchConfig_t launch_config{};
  launch_config.gridDim = kNumExperts;
  launch_config.blockDim = kBlockSize;
  launch_config.dynamicSmemBytes = 0;
  launch_config.stream = stream;
  launch_config.attrs = launch_attrs;
  launch_config.numAttrs = 1;

  // NOTE(review): the cudaError_t from cudaLaunchKernelEx is ignored here
  // (as in the upstream TRT-LLM/sglang code); launch failures surface at the
  // caller's next error check on this stream.
  cudaLaunchKernelEx(
      &launch_config,
      router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens,
                                     kNumExperts, kHiddenDim>,
      output, mat_a, mat_b);
}

// Explicit instantiations of the launcher. The router GEMM is only used for
// small batches (1..16 tokens) at DeepSeek V3's hidden dim of 7168; a macro
// generates the identical set of symbols the hand-written list produced.
#define INSTANTIATE_ROUTER_GEMM_BF16_OUT(NUM_TOKENS, NUM_EXPERTS)         \
  template void invokeRouterGemmBf16Output<__nv_bfloat16, NUM_TOKENS,     \
                                           NUM_EXPERTS, 7168>(            \
      __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*,         \
      cudaStream_t);

// Template instantiations for DEFAULT_NUM_EXPERTS experts
INSTANTIATE_ROUTER_GEMM_BF16_OUT(1, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(2, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(3, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(4, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(5, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(6, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(7, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(8, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(9, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(10, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(11, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(12, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(13, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(14, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(15, 256)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(16, 256)

// Template instantiations for KIMI_K2_NUM_EXPERTS experts
INSTANTIATE_ROUTER_GEMM_BF16_OUT(1, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(2, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(3, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(4, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(5, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(6, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(7, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(8, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(9, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(10, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(11, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(12, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(13, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(14, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(15, 384)
INSTANTIATE_ROUTER_GEMM_BF16_OUT(16, 384)

#undef INSTANTIATE_ROUTER_GEMM_BF16_OUT
Loading