From 1c33d71655f8360ff94a9db004c6c259343b3e0f Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 28 Jun 2025 20:36:16 -0700 Subject: [PATCH 1/8] migrate kernel --- sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 245 +++++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100644 sgl-kernel/csrc/gemm/dsv3_router_gemm.cu diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu new file mode 100644 index 000000000000..f9ce23cf4fa8 --- /dev/null +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ -0,0 +1,245 @@ +/* +* Adapted from +* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu +* +* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +#include +#include + +#include + +#include "cuda_bf16.h" +#include "cuda_runtime.h" +#include "utils.h" + + +// Custom FMA implementation using PTX assembly instructions +__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) +{ + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion +template +__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) +{ + __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); + +#pragma unroll + for (int i = 0; i < VPT; i++) + { + dst[i] = __bfloat162float(bf16_ptr[i]); + } +} + +template +__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) +{ + // Each block handles one expert column + int const n_idx = blockIdx.x; + int const tid = threadIdx.x; + constexpr int kWarpSize = 32; + constexpr int kNumWarps = kBlockSize / kWarpSize; + // Constants for this kernel + constexpr int k_elems_per_k_iteration = VPT * kBlockSize; + constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations + + // Initialize accumulators for all M rows + float acc[kNumTokens] = {}; + + // Shared memory for warp-level reduction + __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps + + // B matrix is in column-major order, so we can directly load a column for the n_idx expert + T const* b_col = mat_b + n_idx * kHiddenDim; + + // Pre-compute k_base values for each iteration to help compiler optimize + // int k_bases[k_iterations]; + int k_bases[k_iterations]; +#pragma unroll + for (int ki = 0; ki < k_iterations; ki++) + { + k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); 
+#endif + + // Process the GEMM in chunks + for (int ki = 0; ki < k_iterations; ki++) + { + int const k_base = k_bases[ki]; + + // Load B matrix values using vector load (8 bf16 values) + uint4 b_vec = *reinterpret_cast(b_col + k_base); + + // Convert B values to float + float b_float[VPT]; + bf16_uint4_to_float8(b_vec, b_float); + +// Process each token +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) + { + // Load both rows of A matrix using vector loads + uint4 a_vec = *reinterpret_cast(mat_a + (m_idx * kHiddenDim) + k_base); + + // Convert A values to float + float a_float[VPT]; + bf16_uint4_to_float8(a_vec, a_float); + +// Process elements in this chunk +#pragma unroll + for (int k = 0; k < VPT; k++) + { + float a = a_float[k]; + float b = b_float[k]; + acc[m_idx] += a * b; + } + } + } + + // Perform warp-level reduction + int const warpSize = 32; + int const warpId = tid / warpSize; + int const laneId = tid % warpSize; + + // Register for warp-level reduction results + float warp_result[kNumTokens]; + +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) + { + warp_result[m_idx] = acc[m_idx]; + } + +// Perform warp-level reduction using optimized butterfly pattern +#pragma unroll + for (int m = 0; m < kNumTokens; m++) + { + float sum = warp_result[m]; + + // Butterfly reduction pattern + sum += __shfl_xor_sync(0xffffffff, sum, 16); + sum += __shfl_xor_sync(0xffffffff, sum, 8); + sum += __shfl_xor_sync(0xffffffff, sum, 4); + sum += __shfl_xor_sync(0xffffffff, sum, 2); + sum += __shfl_xor_sync(0xffffffff, sum, 1); + + // Only the first thread in each warp stores to shared memory + if (laneId == 0) + { + sm_reduction[m][warpId] = sum; + } + } + + __syncthreads(); + + // Final reduction across warps (only first thread) + if (tid == 0) + { +#pragma unroll + for (int m = 0; m < kNumTokens; m++) + { + float final_sum = 0.0f; + +// Sum across the kNumWarps +#pragma unroll + for (int w = 0; w < kNumWarps; w++) + { + final_sum += 
sm_reduction[m][w]; + } + + // Write final result + out[m * kNumExperts + n_idx] = final_sum; + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) +{ + constexpr int VPT = 16 / sizeof(T); + constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + TORCH_CHECK(cudaLaunchKernelEx( + &config, router_gemm_kernel, output, mat_a, mat_b)); +} + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void 
tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); From d411a091afd083b2fe4e9edcca9463977e49d77d Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 28 Jun 2025 20:59:05 -0700 Subject: [PATCH 2/8] add kernel launcher --- sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 64 
++++++++++++++++++------ 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu index f9ce23cf4fa8..5d1921416479 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ -196,50 +196,82 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_ &config, router_gemm_kernel, output, mat_a, mat_b)); } -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>( 
+template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>( float*, __nv_bfloat16 const*, 
__nv_bfloat16 const*, cudaStream_t); -template void tensorrt_llm::kernels::dsv3MinLatencyKernels::invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>( +template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>( float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + + +void router_gemm( + torch::Tensor& output, // [num_tokens, num_experts] + torch::Tensor& mat_a, // [num_tokens, hidden_dim] + torch::Tensor& mat_b, // [num_experts, hidden_dim] +) +{ + CHECK_INPUT(output); + CHECK_INPUT(mat_a); + CHECK_INPUT(mat_b); + + const int num_experts = mat_b.size(0); + const int num_tokens = mat_a.size(0); + const int hidden_dim = mat_a.size(1); + + TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); + TORCH_CHECK(num_tokens <= 16, "Currently num_tokens must be less than or equal to 16 for router_gemm"); + TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); + TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); + TORCH_CHECK(output.dtype() == torch::kBFloat16, "output must be bf16"); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + invokeRouterGemm<__nv_bfloat16, num_tokens, num_experts, hidden_dim>( + static_cast<__nv_bfloat16*>(output.data_ptr()), + static_cast<__nv_bfloat16*>(mat_a.data_ptr()), + static_cast<__nv_bfloat16*>(mat_b.data_ptr()), + stream + ); +} From 09e8bb03329787dc399cfb8e53bafbaeadc0bac2 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 28 Jun 2025 22:36:06 -0700 Subject: [PATCH 3/8] add python bindings --- python/sglang/srt/models/deepseek_v2.py | 8 ++++++-- sgl-kernel/csrc/common_extension.cc | 3 +++ sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 5 ++--- sgl-kernel/include/sgl_kernel_ops.h | 5 +++++ sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/gemm.py | 11 +++++++++++ 6 files changed, 28 insertions(+), 5 deletions(-) diff --git 
a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f1ab8c3e7e38..0fce94ba4d42 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -112,7 +112,7 @@ _is_cpu = is_cpu() if _is_cuda: - from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 + from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2, router_gemm elif _is_cpu and _is_cpu_amx_available: pass else: @@ -224,7 +224,11 @@ def forward(self, hidden_states): True, # is_vnni ) - logits = F.linear(hidden_states, self.weight, None) + # Use customized router gemm for small batch sizes + if hidden_states.shape[0] <= 16: + logits = router_gemm(hidden_states, self.weight) + else: + logits = F.linear(hidden_states, self.weight, None) return logits diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index d8629e6ab56b..91bf9edf03f6 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -158,6 +158,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " Tensor expert_offsets, Tensor sf_offsets) -> ()"); m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + m.def("router_gemm(Tensor!
output, Tensor mat_a, Tensor mat_b) -> ()"); + m.impl("router_gemm", torch::kCUDA, &router_gemm); + /* * From csrc/moe */ diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu index 5d1921416479..4757b7724ed7 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ -247,8 +247,8 @@ template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>( void router_gemm( torch::Tensor& output, // [num_tokens, num_experts] - torch::Tensor& mat_a, // [num_tokens, hidden_dim] - torch::Tensor& mat_b, // [num_experts, hidden_dim] + const torch::Tensor& mat_a, // [num_tokens, hidden_dim] + const torch::Tensor& mat_b, // [num_experts, hidden_dim] ) { CHECK_INPUT(output); @@ -265,7 +265,6 @@ void router_gemm( TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); TORCH_CHECK(output.dtype() == torch::kBFloat16, "output must be bf16"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); invokeRouterGemm<__nv_bfloat16, num_tokens, num_experts, hidden_dim>( diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 30bc5dab5827..0c9f45fb4065 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -200,6 +200,11 @@ void bmm_fp8( at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); +void router_gemm( + torch::Tensor& output, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b +); void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 4cba8cc7808d..4a8538ac9b90 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -39,6 +39,7 @@ int8_scaled_mm, qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm, + router_gemm, 
scaled_fp4_experts_quant, scaled_fp4_quant, sgl_per_tensor_quant_fp8, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 2946c8ae97ab..50ab0ad3cb09 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -258,6 +258,17 @@ def qserve_w4a8_per_group_gemm( ) return out_feats +def router_gemm( + hidden_states: torch.Tensor, + router_weights: torch.Tensor, +) -> torch.Tensor: + output = torch.empty(hidden_states.shape[0], router_weights.shape[0], device=hidden_states.device, dtype=hidden_states.dtype) + torch.ops.sgl_kernel.router_gemm( + output, + hidden_states, + router_weights, + ) + return output def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape): output_tensor = torch.empty( From a00f12367b7babf572a1953ae384413219130885 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 29 Jun 2025 00:30:19 -0700 Subject: [PATCH 4/8] fix lint --- sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 382 +++++++++++------------ sgl-kernel/include/sgl_kernel_ops.h | 6 +- sgl-kernel/python/sgl_kernel/gemm.py | 9 +- 3 files changed, 191 insertions(+), 206 deletions(-) diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu index 4757b7724ed7..022cb8f0c2a2 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ -1,21 +1,21 @@ /* -* Adapted from -* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu -* -* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #include #include @@ -26,251 +26,233 @@ #include "cuda_runtime.h" #include "utils.h" - // Custom FMA implementation using PTX assembly instructions -__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) -{ - asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" - : "=l"(reinterpret_cast(d)) - : "l"(reinterpret_cast(a)), "l"(reinterpret_cast(b)), - "l"(reinterpret_cast(c))); +__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) { + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); } // Convert 8 bfloat16 values from a uint4 to float array - optimized conversion template -__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) -{ - __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); +__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) { + __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); #pragma unroll - for (int i = 0; i < VPT; i++) - { - dst[i] = __bfloat162float(bf16_ptr[i]); - } + for (int i = 0; i < VPT; i++) { + dst[i] = __bfloat162float(bf16_ptr[i]); + } } template -__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) -{ - // Each block handles one expert column - int const n_idx = blockIdx.x; - int const tid = threadIdx.x; - constexpr int kWarpSize = 32; - constexpr int kNumWarps = kBlockSize / kWarpSize; - // Constants for this kernel - constexpr int k_elems_per_k_iteration = VPT * kBlockSize; - constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations - - // Initialize accumulators for all M rows - float acc[kNumTokens] = {}; - - // Shared memory for warp-level reduction - __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps - - // B matrix is in 
column-major order, so we can directly load a column for the n_idx expert - T const* b_col = mat_b + n_idx * kHiddenDim; - - // Pre-compute k_base values for each iteration to help compiler optimize - // int k_bases[k_iterations]; - int k_bases[k_iterations]; +__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) { + // Each block handles one expert column + int const n_idx = blockIdx.x; + int const tid = threadIdx.x; + constexpr int kWarpSize = 32; + constexpr int kNumWarps = kBlockSize / kWarpSize; + // Constants for this kernel + constexpr int k_elems_per_k_iteration = VPT * kBlockSize; + constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations + + // Initialize accumulators for all M rows + float acc[kNumTokens] = {}; + + // Shared memory for warp-level reduction + __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps + + // B matrix is in column-major order, so we can directly load a column for the n_idx expert + T const* b_col = mat_b + n_idx * kHiddenDim; + + // Pre-compute k_base values for each iteration to help compiler optimize + // int k_bases[k_iterations]; + int k_bases[k_iterations]; #pragma unroll - for (int ki = 0; ki < k_iterations; ki++) - { - k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; - } + for (int ki = 0; ki < k_iterations; ki++) { + k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; + } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); + asm volatile("griddepcontrol.wait;"); #endif - // Process the GEMM in chunks - for (int ki = 0; ki < k_iterations; ki++) - { - int const k_base = k_bases[ki]; + // Process the GEMM in chunks + for (int ki = 0; ki < k_iterations; ki++) { + int const k_base = k_bases[ki]; - // Load B matrix values using vector load (8 bf16 values) - uint4 b_vec = *reinterpret_cast(b_col + k_base); + // Load B matrix values using vector load (8 bf16 values) + uint4 b_vec = 
*reinterpret_cast(b_col + k_base); - // Convert B values to float - float b_float[VPT]; - bf16_uint4_to_float8(b_vec, b_float); + // Convert B values to float + float b_float[VPT]; + bf16_uint4_to_float8(b_vec, b_float); // Process each token #pragma unroll - for (int m_idx = 0; m_idx < kNumTokens; m_idx++) - { - // Load both rows of A matrix using vector loads - uint4 a_vec = *reinterpret_cast(mat_a + (m_idx * kHiddenDim) + k_base); + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + // Load both rows of A matrix using vector loads + uint4 a_vec = *reinterpret_cast(mat_a + (m_idx * kHiddenDim) + k_base); - // Convert A values to float - float a_float[VPT]; - bf16_uint4_to_float8(a_vec, a_float); + // Convert A values to float + float a_float[VPT]; + bf16_uint4_to_float8(a_vec, a_float); // Process elements in this chunk #pragma unroll - for (int k = 0; k < VPT; k++) - { - float a = a_float[k]; - float b = b_float[k]; - acc[m_idx] += a * b; - } - } + for (int k = 0; k < VPT; k++) { + float a = a_float[k]; + float b = b_float[k]; + acc[m_idx] += a * b; + } } + } - // Perform warp-level reduction - int const warpSize = 32; - int const warpId = tid / warpSize; - int const laneId = tid % warpSize; + // Perform warp-level reduction + int const warpSize = 32; + int const warpId = tid / warpSize; + int const laneId = tid % warpSize; - // Register for warp-level reduction results - float warp_result[kNumTokens]; + // Register for warp-level reduction results + float warp_result[kNumTokens]; #pragma unroll - for (int m_idx = 0; m_idx < kNumTokens; m_idx++) - { - warp_result[m_idx] = acc[m_idx]; - } + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + warp_result[m_idx] = acc[m_idx]; + } // Perform warp-level reduction using optimized butterfly pattern #pragma unroll - for (int m = 0; m < kNumTokens; m++) - { - float sum = warp_result[m]; - - // Butterfly reduction pattern - sum += __shfl_xor_sync(0xffffffff, sum, 16); - sum += __shfl_xor_sync(0xffffffff, sum, 8); - sum 
+= __shfl_xor_sync(0xffffffff, sum, 4); - sum += __shfl_xor_sync(0xffffffff, sum, 2); - sum += __shfl_xor_sync(0xffffffff, sum, 1); - - // Only the first thread in each warp stores to shared memory - if (laneId == 0) - { - sm_reduction[m][warpId] = sum; - } + for (int m = 0; m < kNumTokens; m++) { + float sum = warp_result[m]; + + // Butterfly reduction pattern + sum += __shfl_xor_sync(0xffffffff, sum, 16); + sum += __shfl_xor_sync(0xffffffff, sum, 8); + sum += __shfl_xor_sync(0xffffffff, sum, 4); + sum += __shfl_xor_sync(0xffffffff, sum, 2); + sum += __shfl_xor_sync(0xffffffff, sum, 1); + + // Only the first thread in each warp stores to shared memory + if (laneId == 0) { + sm_reduction[m][warpId] = sum; } + } - __syncthreads(); + __syncthreads(); - // Final reduction across warps (only first thread) - if (tid == 0) - { + // Final reduction across warps (only first thread) + if (tid == 0) { #pragma unroll - for (int m = 0; m < kNumTokens; m++) - { - float final_sum = 0.0f; + for (int m = 0; m < kNumTokens; m++) { + float final_sum = 0.0f; // Sum across the kNumWarps #pragma unroll - for (int w = 0; w < kNumWarps; w++) - { - final_sum += sm_reduction[m][w]; - } - - // Write final result - out[m * kNumExperts + n_idx] = final_sum; - } + for (int w = 0; w < kNumWarps; w++) { + final_sum += sm_reduction[m][w]; + } + + // Write final result + out[m * kNumExperts + n_idx] = final_sum; } + } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + asm volatile("griddepcontrol.launch_dependents;"); #endif } template -void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) -{ - constexpr int VPT = 16 / sizeof(T); - constexpr int kBlockSize = 128; - cudaLaunchConfig_t config; - config.gridDim = kNumExperts; - config.blockDim = kBlockSize; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = 
cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - TORCH_CHECK(cudaLaunchKernelEx( - &config, router_gemm_kernel, output, mat_a, mat_b)); +void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) { + constexpr int VPT = 16 / sizeof(T); + constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + TORCH_CHECK(cudaLaunchKernelEx( + &config, router_gemm_kernel, output, mat_a, mat_b)); } -template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); - -template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 4, 256, 
7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>( - float*, __nv_bfloat16 const*, 
__nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>( - float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void +invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); void router_gemm( - torch::Tensor& output, // [num_tokens, num_experts] - const torch::Tensor& mat_a, // [num_tokens, hidden_dim] - const torch::Tensor& mat_b, // [num_experts, hidden_dim] -) -{ - CHECK_INPUT(output); - CHECK_INPUT(mat_a); - CHECK_INPUT(mat_b); - - const int num_experts = mat_b.size(0); - const int num_tokens = mat_a.size(0); - const int hidden_dim = mat_a.size(1); - - TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); - TORCH_CHECK(num_tokens <= 16, "Currently num_tokens must be less than or equal to 16 for router_gemm"); - TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); - TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); - TORCH_CHECK(output.dtype() == torch::kBFloat16, "output must be bf16"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - invokeRouterGemm<__nv_bfloat16, 
num_tokens, num_experts, hidden_dim>( - static_cast<__nv_bfloat16*>(output.data_ptr()), - static_cast<__nv_bfloat16*>(mat_a.data_ptr()), - static_cast<__nv_bfloat16*>(mat_b.data_ptr()), - stream - ); + torch::Tensor& output, // [num_tokens, num_experts] + const torch::Tensor& mat_a, // [num_tokens, hidden_dim] + const torch::Tensor& mat_b, // [num_experts, hidden_dim] +) { + CHECK_INPUT(output); + CHECK_INPUT(mat_a); + CHECK_INPUT(mat_b); + + const int num_experts = mat_b.size(0); + const int num_tokens = mat_a.size(0); + const int hidden_dim = mat_a.size(1); + + TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); + TORCH_CHECK(num_tokens <= 16, "Currently num_tokens must be less than or equal to 16 for router_gemm"); + TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); + TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); + TORCH_CHECK(output.dtype() == torch::kBFloat16, "output must be bf16"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + invokeRouterGemm<__nv_bfloat16, num_tokens, num_experts, hidden_dim>( + static_cast<__nv_bfloat16*>(output.data_ptr()), + static_cast<__nv_bfloat16*>(mat_a.data_ptr()), + static_cast<__nv_bfloat16*>(mat_b.data_ptr()), + stream); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 0c9f45fb4065..78470f431d25 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -200,11 +200,7 @@ void bmm_fp8( at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); -void router_gemm( - torch::Tensor& output, - const torch::Tensor& mat_a, - const torch::Tensor& mat_b -); +void router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b); diff --git a/sgl-kernel/python/sgl_kernel/gemm.py 
b/sgl-kernel/python/sgl_kernel/gemm.py index 50ab0ad3cb09..04a4b4d40716 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -258,11 +258,17 @@ def qserve_w4a8_per_group_gemm( ) return out_feats + def router_gemm( hidden_states: torch.Tensor, router_weights: torch.Tensor, ) -> torch.Tensor: - output = torch.empty(hidden_states.shape[0], router_weights.shape[0], device=hidden_states.device, dtype=hidden_states.dtype) + output = torch.empty( + hidden_states.shape[0], + router_weights.shape[0], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) torch.ops.sgl_kernel.router_gemm( output, hidden_states, @@ -270,6 +276,7 @@ def router_gemm( ) return output + def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape): output_tensor = torch.empty( output_tensor_shape, From 3a62bd19b7a93dae93651d195668371da2d0229b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 29 Jun 2025 01:58:42 -0700 Subject: [PATCH 5/8] restore deepseek_v2.py --- python/sglang/srt/models/deepseek_v2.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0fce94ba4d42..f1ab8c3e7e38 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -112,7 +112,7 @@ _is_cpu = is_cpu() if _is_cuda: - from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2, router_gemm + from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 elif _is_cpu and _is_cpu_amx_available: pass else: @@ -224,11 +224,7 @@ def forward(self, hidden_states): True, # is_vnni ) - # Use cutomized router gemm for small batch sizes - if hidden_states.shape[0] <= 16: - logits = router_gemm(hidden_states, self.weight) - else: - logits = F.linear(hidden_states, self.weight, None) + logits = F.linear(hidden_states, self.weight, None) return logits From 4b2541ed8c02e6c22c6885a23be91552b4418268 Mon Sep 17 00:00:00 2001 
From: Baizhou Zhang Date: Sun, 29 Jun 2025 15:56:20 -0700 Subject: [PATCH 6/8] Fix compile bug --- sgl-kernel/CMakeLists.txt | 1 + sgl-kernel/csrc/common_extension.cc | 4 +- sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 56 ++++++++++++++++++------ sgl-kernel/include/sgl_kernel_ops.h | 2 +- sgl-kernel/python/sgl_kernel/__init__.py | 2 +- sgl-kernel/python/sgl_kernel/gemm.py | 6 +-- 6 files changed, 51 insertions(+), 20 deletions(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 72f3c069998d..739b2b1c2277 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -222,6 +222,7 @@ set(SOURCES "csrc/gemm/awq_kernel.cu" "csrc/gemm/bmm_fp8.cu" "csrc/gemm/dsv3_fused_a_gemm.cu" + "csrc/gemm/dsv3_router_gemm.cu" "csrc/gemm/fp8_blockwise_gemm_kernel.cu" "csrc/gemm/fp8_gemm_kernel.cu" "csrc/gemm/int8_gemm_kernel.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 91bf9edf03f6..941d0f836095 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -158,8 +158,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " Tensor expert_offsets, Tensor sf_offsets) -> ()"); m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); - m.def("router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - m.impl("router_gemm", torch::kCUDA, &router_gemm); + m.def("dsv3_router_gemm(Tensor! 
output, Tensor mat_a, Tensor mat_b) -> ()"); + m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); /* * From csrc/moe diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu index 022cb8f0c2a2..682021e581c5 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ -1,6 +1,7 @@ /* * Adapted from * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp * * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. * @@ -229,30 +230,59 @@ invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __n template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -void router_gemm( +template +struct LoopUnroller { + static void + unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) { + if (num_tokens == kBegin) { + invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream); + } else { + LoopUnroller::unroll(num_tokens, output, input, weights, stream); + } + } +}; + +template +struct LoopUnroller { + static void + unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) { + if (num_tokens == kEnd) { + invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } +}; + +void dsv3_router_gemm( torch::Tensor& output, // [num_tokens, num_experts] const torch::Tensor& mat_a, // [num_tokens, hidden_dim] - const torch::Tensor& mat_b, // [num_experts, hidden_dim] + const torch::Tensor& mat_b // [num_experts, hidden_dim] ) { - CHECK_INPUT(output); - 
CHECK_INPUT(mat_a); - CHECK_INPUT(mat_b); + TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2); - const int num_experts = mat_b.size(0); const int num_tokens = mat_a.size(0); - const int hidden_dim = mat_a.size(1); + constexpr int num_experts = 256; + constexpr int hidden_dim = 7168; TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); - TORCH_CHECK(num_tokens <= 16, "Currently num_tokens must be less than or equal to 16 for router_gemm"); + TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168"); + TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256"); + TORCH_CHECK( + num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm"); TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); - TORCH_CHECK(output.dtype() == torch::kBFloat16, "output must be bf16"); + TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32"); + + auto const sm = getSMVersion(); + TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90"); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - invokeRouterGemm<__nv_bfloat16, num_tokens, num_experts, hidden_dim>( - static_cast<__nv_bfloat16*>(output.data_ptr()), - static_cast<__nv_bfloat16*>(mat_a.data_ptr()), - static_cast<__nv_bfloat16*>(mat_b.data_ptr()), + LoopUnroller<1, 16, num_experts, hidden_dim>::unroll( + num_tokens, + reinterpret_cast(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 78470f431d25..abea81e3f7e4 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -200,7 +200,7 @@ void bmm_fp8( at::Tensor workspace_buffer, 
int64_t cublas_handle, int64_t cuda_stream); -void router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); +void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 4a8538ac9b90..2353004ce754 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -34,12 +34,12 @@ bmm_fp8, cutlass_scaled_fp4_mm, dsv3_fused_a_gemm, + dsv3_router_gemm, fp8_blockwise_scaled_mm, fp8_scaled_mm, int8_scaled_mm, qserve_w4a8_per_chn_gemm, qserve_w4a8_per_group_gemm, - router_gemm, scaled_fp4_experts_quant, scaled_fp4_quant, sgl_per_tensor_quant_fp8, diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 04a4b4d40716..6ec4ce78ab32 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -259,7 +259,7 @@ def qserve_w4a8_per_group_gemm( return out_feats -def router_gemm( +def dsv3_router_gemm( hidden_states: torch.Tensor, router_weights: torch.Tensor, ) -> torch.Tensor: @@ -267,9 +267,9 @@ def router_gemm( hidden_states.shape[0], router_weights.shape[0], device=hidden_states.device, - dtype=hidden_states.dtype, + dtype=torch.float32, ) - torch.ops.sgl_kernel.router_gemm( + torch.ops.sgl_kernel.dsv3_router_gemm( output, hidden_states, router_weights, From 5035b24b3cf251f0bfcf0d408d976fb115c81119 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 29 Jun 2025 17:15:08 -0700 Subject: [PATCH 7/8] add benchmarks and tests --- .../benchmark/bench_dsv3_router_gemm.py | 56 +++++++++++++++++++ sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 4 +- sgl-kernel/tests/test_dsv3_router_gemm.py | 32 +++++++++++ 3 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 
sgl-kernel/benchmark/bench_dsv3_router_gemm.py create mode 100644 sgl-kernel/tests/test_dsv3_router_gemm.py diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py new file mode 100644 index 000000000000..16b3143f0623 --- /dev/null +++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py @@ -0,0 +1,56 @@ +import argparse + +import torch +import torch.nn.functional as F +import triton +import triton.testing +from sgl_kernel import dsv3_router_gemm + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=[i + 1 for i in range(16)], + x_log=False, + line_arg="impl", + line_vals=["torch", "sgl-kernel"], + line_names=["torch", "dsv3_router_gemm"], + styles=[("blue", "-"), ("orange", "-")], + ylabel="TFLOPs", + plot_name="input-bf16-output-fp32 dsv3 router gemm throughput", + args={}, + ) +) +def benchmark(num_tokens, impl): + # M: num_tokens, K: hidden_dim, N: num_experts + M, K, N = num_tokens, 7168, 256 + + mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() + mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous() + + quantiles = [0.5, 0.2, 0.8] + + if impl == "torch": + + def runner(): + F.linear(mat_a, mat_b).to(torch.float32) + + elif impl == "sgl-kernel": + + def runner(): + dsv3_router_gemm(mat_a, mat_b) + + ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles) + + def tflops(t_ms): + flops = 2 * M * K * N + return flops / (t_ms * 1e-3) / 1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + args = parser.parse_args() + + benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm") diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu index 682021e581c5..53dec033a634 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ 
-178,8 +178,8 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_ attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); config.numAttrs = 1; config.attrs = attrs; - TORCH_CHECK(cudaLaunchKernelEx( - &config, router_gemm_kernel, output, mat_a, mat_b)); + cudaLaunchKernelEx( + &config, router_gemm_kernel, output, mat_a, mat_b); } template void diff --git a/sgl-kernel/tests/test_dsv3_router_gemm.py b/sgl-kernel/tests/test_dsv3_router_gemm.py new file mode 100644 index 000000000000..1b60bcf920d5 --- /dev/null +++ b/sgl-kernel/tests/test_dsv3_router_gemm.py @@ -0,0 +1,32 @@ +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import dsv3_router_gemm + + +@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)]) +def test_dsv3_router_gemm(num_tokens): + num_experts = 256 + hidden_dim = 7168 + + mat_a = torch.randn( + (num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda" + ).contiguous() + mat_b = torch.randn( + (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda" + ).contiguous() + output = torch.empty( + (num_tokens, num_experts), dtype=torch.float32, device="cuda" + ).contiguous() + + ref = F.linear(mat_a, mat_b).to(torch.float32) + + output = dsv3_router_gemm(mat_a, mat_b) + + assert torch.allclose( + output, ref, rtol=1e-2, atol=1e-3 + ), "Router GEMM output mismatch with torch.nn.functional.linear reference" + + +if __name__ == "__main__": + pytest.main([__file__]) From 9440b50b6389b875680a55bab04e07e784f21cb8 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 29 Jun 2025 17:25:17 -0700 Subject: [PATCH 8/8] fix --- sgl-kernel/csrc/gemm/dsv3_router_gemm.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu index 53dec033a634..410bbcefd3a6 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu @@ -21,8 +21,6 @@ #include #include 
-#include - #include "cuda_bf16.h" #include "cuda_runtime.h" #include "utils.h"