diff --git a/CMakeLists.txt b/CMakeLists.txt index b00941a42a41..a6f7f69468d1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1101,6 +1101,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Not building Marlin MOE kernels as no compatible archs found" " in CUDA target architectures") endif() + + # DeepSeek V3 router GEMM kernel - requires SM90+ + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS) + set(DSV3_ROUTER_GEMM_SRC + "csrc/moe/dsv3_router_gemm_entry.cu" + "csrc/moe/dsv3_router_gemm_float_out.cu" + "csrc/moe/dsv3_router_gemm_bf16_out.cu") + set_gencode_flags_for_srcs( + SRCS "${DSV3_ROUTER_GEMM_SRC}" + CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}") + list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}") + message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}") + else() + message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found" + " (requires SM90+ and CUDA >= 12.0)") + endif() endif() message(STATUS "Enabling moe extension.") diff --git a/csrc/moe/dsv3_router_gemm_bf16_out.cu b/csrc/moe/dsv3_router_gemm_bf16_out.cu new file mode 100644 index 000000000000..8c7000ccf352 --- /dev/null +++ b/csrc/moe/dsv3_router_gemm_bf16_out.cu @@ -0,0 +1,291 @@ +/* + * Adapted from SGLang's sgl-kernel implementation, which was adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include "dsv3_router_gemm_utils.h" + +// Custom FMA implementation using PTX assembly instructions +__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, + float2 const& c) { + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion +template +__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, + float* dst) { + __nv_bfloat16* bf16_ptr = + reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); + +#pragma unroll + for (int i = 0; i < VPT; i++) { + dst[i] = __bfloat162float(bf16_ptr[i]); + } +} + +template +__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output( + __nv_bfloat16* out, T const* mat_a, T const* mat_b) { + // Each block handles one expert column + int const n_idx = blockIdx.x; + int const tid = threadIdx.x; + constexpr int kWarpSize = 32; + constexpr int kNumWarps = kBlockSize / kWarpSize; + // Constants for this kernel + constexpr int k_elems_per_k_iteration = VPT * kBlockSize; + constexpr int k_iterations = + kHiddenDim / k_elems_per_k_iteration; // Total K iterations + + // Initialize accumulators for all M rows + float acc[kNumTokens] = {}; + + // Shared memory for warp-level reduction + __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps + + // B matrix is in column-major order, so we can directly load a column for the + // n_idx expert + T const* b_col = mat_b + n_idx * kHiddenDim; + + // Pre-compute k_base values for each iteration to help compiler optimize + int k_bases[k_iterations]; +#pragma unroll + for (int ki = 0; ki < k_iterations; ki++) { + k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // Process the GEMM in chunks + for (int ki = 0; ki < k_iterations; ki++) { + int const k_base = k_bases[ki]; + + // Load B matrix values using vector load (8 bf16 values) + uint4 b_vec = *reinterpret_cast(b_col + k_base); + + // Convert B values to float + float b_float[VPT]; + bf16_uint4_to_float8(b_vec, b_float); + +// Process each token +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + // Load both rows of A matrix using vector loads + uint4 a_vec = *reinterpret_cast( + mat_a + (m_idx * kHiddenDim) + k_base); + + // Convert A values to float + float a_float[VPT]; + bf16_uint4_to_float8(a_vec, a_float); + +// Process elements in this chunk +#pragma unroll + for (int k = 0; k < VPT; k++) { + float a = a_float[k]; + float b = b_float[k]; + acc[m_idx] += a * b; + } + } + } + + // Perform warp-level reduction + int const warpSize = 32; + int const warpId = tid / warpSize; + int const laneId = tid % warpSize; + + // Register for warp-level reduction results + float warp_result[kNumTokens]; + +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + warp_result[m_idx] = acc[m_idx]; + } + +// Perform warp-level reduction using optimized butterfly pattern +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float sum = warp_result[m]; + + // Butterfly reduction pattern + sum += __shfl_xor_sync(0xffffffff, sum, 16); + sum += __shfl_xor_sync(0xffffffff, sum, 8); + sum += __shfl_xor_sync(0xffffffff, sum, 4); + sum += __shfl_xor_sync(0xffffffff, sum, 2); + sum += __shfl_xor_sync(0xffffffff, sum, 1); + + // Only the first thread in each warp stores to shared memory + if (laneId == 0) { + sm_reduction[m][warpId] = sum; + } + } + + __syncthreads(); + + // Final reduction across warps (only first thread) + if (tid == 0) { +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float final_sum = 0.0f; + +// Sum across the kNumWarps +#pragma unroll + for (int w = 0; w < kNumWarps; w++) { + final_sum += sm_reduction[m][w]; + } + + // Write final result + out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum); + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, + T const* mat_b, cudaStream_t stream) { + constexpr int VPT = 16 / sizeof(T); + constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx( + &config, + router_gemm_kernel_bf16_output, + output, mat_a, mat_b); +} + +// Template instantiations for DEFAULT_NUM_EXPERTS experts +template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +// Template instantiations for KIMI_K2_NUM_EXPERTS experts +template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 384, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); diff --git a/csrc/moe/dsv3_router_gemm_entry.cu b/csrc/moe/dsv3_router_gemm_entry.cu new file mode 100644 index 000000000000..1ba97bd76406 --- /dev/null +++ b/csrc/moe/dsv3_router_gemm_entry.cu @@ -0,0 +1,163 @@ +/* + * Adapted from SGLang's sgl-kernel implementation, which was adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include "dsv3_router_gemm_utils.h" + +static constexpr int DEFAULT_NUM_EXPERTS = 256; +static constexpr int KIMI_K2_NUM_EXPERTS = 384; +static constexpr int DEFAULT_HIDDEN_DIM = 7168; + +template +void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, + cudaStream_t stream); + +template +void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, + T const* mat_b, cudaStream_t stream); + +template +struct LoopUnroller { + static void unroll_float_output(int num_tokens, float* output, + __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, + cudaStream_t stream) { + if (num_tokens == kBegin) { + invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, + kHiddenDim>(output, input, weights, stream); + } else { + LoopUnroller::unroll_float_output(num_tokens, output, input, + weights, stream); + } + } + + static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output, + __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, + cudaStream_t stream) { + if (num_tokens == kBegin) { + invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, + kHiddenDim>(output, input, weights, stream); + } else { + LoopUnroller::unroll_bf16_output(num_tokens, output, input, + weights, stream); + } + } +}; + +template +struct LoopUnroller { + static void unroll_float_output(int num_tokens, float* output, + __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, + cudaStream_t stream) { + if (num_tokens == kEnd) { + invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>( + output, input, weights, stream); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } + + static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output, + __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, + cudaStream_t stream) { + if (num_tokens == kEnd) { + invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>( + output, input, weights, stream); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } +}; + +void dsv3_router_gemm(at::Tensor& output, // [num_tokens, num_experts] + const at::Tensor& mat_a, // [num_tokens, hidden_dim] + const at::Tensor& mat_b // [num_experts, hidden_dim] +) { + TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2); + + const int num_tokens = mat_a.size(0); + const int num_experts = mat_b.size(0); + const int hidden_dim = mat_a.size(1); + + TORCH_CHECK(mat_a.size(1) == mat_b.size(1), + "mat_a and mat_b must have the same hidden_dim"); + TORCH_CHECK(hidden_dim == DEFAULT_HIDDEN_DIM, + "Expected hidden_dim=", DEFAULT_HIDDEN_DIM, + ", but got hidden_dim=", hidden_dim); + TORCH_CHECK( + num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS, + "Expected num_experts=", DEFAULT_NUM_EXPERTS, + " or num_experts=", KIMI_K2_NUM_EXPERTS, + ", but got num_experts=", num_experts); + TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, + "currently num_tokens must be less than or equal to 16 for " + "router_gemm"); + TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "mat_a must be bf16"); + TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "mat_b must be bf16"); + TORCH_CHECK(output.dtype() == at::kFloat || output.dtype() == at::kBFloat16, + "output must be float32 or bf16"); + + auto const sm = getSMVersion(); + TORCH_CHECK(sm >= 90 && sm <= 103, "required SM_103 >= CUDA ARCH >= SM_90"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (output.dtype() == at::kFloat) { + if (num_experts == DEFAULT_NUM_EXPERTS) { + LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>:: + unroll_float_output( + num_tokens, reinterpret_cast(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream); + } else if (num_experts == KIMI_K2_NUM_EXPERTS) { + LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>:: + unroll_float_output( + num_tokens, reinterpret_cast(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream); + } + } else if (output.dtype() == at::kBFloat16) { + if (num_experts == DEFAULT_NUM_EXPERTS) { + LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>:: + unroll_bf16_output( + num_tokens, + reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream); + } else if (num_experts == KIMI_K2_NUM_EXPERTS) { + LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>:: + unroll_bf16_output( + num_tokens, + reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream); + } + } +} diff --git a/csrc/moe/dsv3_router_gemm_float_out.cu b/csrc/moe/dsv3_router_gemm_float_out.cu new file mode 100644 index 000000000000..483eb1e023eb --- /dev/null +++ b/csrc/moe/dsv3_router_gemm_float_out.cu @@ -0,0 +1,291 @@ +/* + * Adapted from SGLang's sgl-kernel implementation, which was adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include "dsv3_router_gemm_utils.h" + +// Custom FMA implementation using PTX assembly instructions +__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, + float2 const& c) { + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion +template +__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, + float* dst) { + __nv_bfloat16* bf16_ptr = + reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); + +#pragma unroll + for (int i = 0; i < VPT; i++) { + dst[i] = __bfloat162float(bf16_ptr[i]); + } +} + +template +__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output( + float* out, T const* mat_a, T const* mat_b) { + // Each block handles one expert column + int const n_idx = blockIdx.x; + int const tid = threadIdx.x; + constexpr int kWarpSize = 32; + constexpr int kNumWarps = kBlockSize / kWarpSize; + // Constants for this kernel + constexpr int k_elems_per_k_iteration = VPT * kBlockSize; + constexpr int k_iterations = + kHiddenDim / k_elems_per_k_iteration; // Total K iterations + + // Initialize accumulators for all M rows + float acc[kNumTokens] = {}; + + // Shared memory for warp-level reduction + __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps + + // B matrix is in column-major order, so we can directly load a column for the + // n_idx expert + T const* b_col = mat_b + n_idx * kHiddenDim; + + // Pre-compute k_base values for each iteration to help compiler optimize + int k_bases[k_iterations]; +#pragma unroll + for (int ki = 0; ki < k_iterations; ki++) { + k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // Process the GEMM in chunks + for (int ki = 0; ki < k_iterations; ki++) { + int const k_base = k_bases[ki]; + + // Load B matrix values using vector load (8 bf16 values) + uint4 b_vec = *reinterpret_cast(b_col + k_base); + + // Convert B values to float + float b_float[VPT]; + bf16_uint4_to_float8(b_vec, b_float); + +// Process each token +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + // Load both rows of A matrix using vector loads + uint4 a_vec = *reinterpret_cast( + mat_a + (m_idx * kHiddenDim) + k_base); + + // Convert A values to float + float a_float[VPT]; + bf16_uint4_to_float8(a_vec, a_float); + +// Process elements in this chunk +#pragma unroll + for (int k = 0; k < VPT; k++) { + float a = a_float[k]; + float b = b_float[k]; + acc[m_idx] += a * b; + } + } + } + + // Perform warp-level reduction + int const warpSize = 32; + int const warpId = tid / warpSize; + int const laneId = tid % warpSize; + + // Register for warp-level reduction results + float warp_result[kNumTokens]; + +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + warp_result[m_idx] = acc[m_idx]; + } + +// Perform warp-level reduction using optimized butterfly pattern +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float sum = warp_result[m]; + + // Butterfly reduction pattern + sum += __shfl_xor_sync(0xffffffff, sum, 16); + sum += __shfl_xor_sync(0xffffffff, sum, 8); + sum += __shfl_xor_sync(0xffffffff, sum, 4); + sum += __shfl_xor_sync(0xffffffff, sum, 2); + sum += __shfl_xor_sync(0xffffffff, sum, 1); + + // Only the first thread in each warp stores to shared memory + if (laneId == 0) { + sm_reduction[m][warpId] = sum; + } + } + + __syncthreads(); + + // Final reduction across warps (only first thread) + if (tid == 0) { +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float final_sum = 0.0f; + +// Sum across the kNumWarps +#pragma unroll + for (int w = 0; w < kNumWarps; w++) { + final_sum += sm_reduction[m][w]; + } + + // Write final result + out[m * kNumExperts + n_idx] = final_sum; + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, + cudaStream_t stream) { + constexpr int VPT = 16 / sizeof(T); + constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx( + &config, + router_gemm_kernel_float_output, + output, mat_a, mat_b); +} + +// Template instantiations for DEFAULT_NUM_EXPERTS experts +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +// Template instantiations for KIMI_K2_NUM_EXPERTS experts +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 384, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); diff --git a/csrc/moe/dsv3_router_gemm_utils.h b/csrc/moe/dsv3_router_gemm_utils.h new file mode 100644 index 000000000000..13b60d6be6a1 --- /dev/null +++ b/csrc/moe/dsv3_router_gemm_utils.h @@ -0,0 +1,43 @@ +/* + * Adapted from SGLang's sgl-kernel implementation, which was adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +inline int getSMVersion() { + auto* props = at::cuda::getCurrentDeviceProperties(); + return props->major * 10 + props->minor; +} + +inline bool getEnvEnablePDL() { + static std::once_flag flag; + static bool enablePDL = false; + std::call_once(flag, [&]() { + if (getSMVersion() >= 90) { + const char* env = std::getenv("TRTLLM_ENABLE_PDL"); + enablePDL = env && env[0] == '1' && env[1] == '\0'; + } + }); + return enablePDL; +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 89d54c47d654..b71db3569447 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -55,4 +55,15 @@ bool moe_permute_unpermute_supported(); void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, - torch::Tensor& output_tensor); \ No newline at end of file + torch::Tensor& output_tensor); + +#ifndef USE_ROCM +// DeepSeek V3 optimized router GEMM kernel for SM90+ +// Computes output = mat_a @ mat_b.T where: +// mat_a: [num_tokens, hidden_dim] in bf16 +// mat_b: [num_experts, hidden_dim] in bf16 +// output: [num_tokens, num_experts] in bf16 or fp32 +// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 +void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, + const torch::Tensor& mat_b); +#endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index fd9b8945e6d2..22b00f20ad57 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -124,6 +124,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, " "Tensor)"); m.impl("grouped_topk", torch::kCUDA, &grouped_topk); + + // DeepSeek V3 optimized router GEMM for SM90+ + m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); + m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); #endif } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 25d57d9aa5ee..e48ba6c997eb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2190,6 +2190,21 @@ def moe_wna16_gemm( ) +def dsv3_router_gemm( + hidden_states: torch.Tensor, + router_weight: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: + output = torch.empty( + hidden_states.shape[0], + router_weight.shape[0], + device=hidden_states.device, + dtype=output_dtype, + ) + torch.ops._moe_C.dsv3_router_gemm(output, hidden_states, router_weight) + return output + + def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 3b3b7a1a37e5..768f4e20b34a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -221,6 +221,73 @@ def forward(self, x): return x +class DeepSeekV2Gate(ReplicatedLinear): + def __init__( + self, + hidden_size: int, + n_experts: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + assert quant_config is None + super().__init__( + hidden_size, + n_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) + + # Unquantized only, will be called "weight". + assert hasattr(self, "weight") + is_hopper_or_blackwell = current_platform.is_device_capability( + (9, 0) + ) or current_platform.is_device_capability_family(100) + SUPPORTED_NUM_EXPERTS = [256, 384] + SUPPORTED_HIDDEN_SIZES = [7168] + + self.allow_dsv3_router_gemm = ( + current_platform.is_cuda() + and is_hopper_or_blackwell + and n_experts in SUPPORTED_NUM_EXPERTS + and hidden_size in SUPPORTED_HIDDEN_SIZES + ) + + self._out_dtype: torch.dtype | None = None + + def set_out_dtype(self, out_dtype: torch.dtype) -> None: + """ + Set out dtype for the router logits. This is needed after + __init__, b/c we need to check if the trtllm kernel is + selected before we decide between bf16 and fp32. + """ + + if self._out_dtype is not None: + raise ValueError("out_dtype has already been set") + else: + self._out_dtype = out_dtype + + @property + def out_dtype(self) -> torch.dtype: + if self._out_dtype is None: + raise ValueError("out_dtype has not been set yet") + return self._out_dtype + + def forward( + self, + x: torch.Tensor, + ) -> tuple[torch.Tensor, None]: + """ + Use specialized GEMM for low batch size for DSV3 and KIMI. + """ + if self.allow_dsv3_router_gemm and x.shape[0] <= 16: + return ops.dsv3_router_gemm( + hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype + ), None + else: + return super().forward(x) + + class DeepseekV2MoE(nn.Module): def __init__( self, @@ -249,10 +316,9 @@ def __init__( "Only silu is supported for now." ) - self.gate = ReplicatedLinear( + self.gate = DeepSeekV2Gate( config.hidden_size, config.n_routed_experts, - bias=False, quant_config=None, prefix=f"{prefix}.gate", ) @@ -325,6 +391,13 @@ def __init__( else None, ) + # NOTE(rob): this is a hack until we finish off the PR for + # merging TRTLLM kernels into the MK framework. Then we can + # query the MonolithicMK for the expected router logits. + self.gate.set_out_dtype( + torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16 + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim)