diff --git a/CMakeLists.txt b/CMakeLists.txt index e59bfef6fc68..f0b1f53af831 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1240,16 +1240,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}") list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}") message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}") - - # DeepSeek V4 fused RMSNorm + router GEMV - same arch gating as DSV3. - set(DSV4_NORM_ROUTER_GEMM_SRC - "csrc/moe/dsv4_norm_router_gemm_entry.cu" - "csrc/moe/dsv4_norm_router_gemm_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${DSV4_NORM_ROUTER_GEMM_SRC}" - CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}") - list(APPEND VLLM_MOE_EXT_SRC "${DSV4_NORM_ROUTER_GEMM_SRC}") - message(STATUS "Building DSV4 norm+router GEMV kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}") else() message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found" " (requires SM90+ and CUDA >= 12.0)") diff --git a/benchmarks/kernels/benchmark_norm_router_gemm.py b/benchmarks/kernels/benchmark_norm_router_gemm.py deleted file mode 100644 index cd50e9159961..000000000000 --- a/benchmarks/kernels/benchmark_norm_router_gemm.py +++ /dev/null @@ -1,183 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark and correctness check for ``ops.dsv4_norm_router_gemm``. - -Two implementations are compared: - - 1. ``unfused`` — ``vllm_ops.rms_norm`` then ``ops.dsv3_router_gemm``, - i.e. the current vLLM hot path (two kernel launches). - 2. ``fused`` — ``ops.dsv4_norm_router_gemm``, the new single-kernel - fused path. - -Both produce ``(normed_x: bf16, router_logits: fp32)``. The correctness -check verifies that ``fused`` and ``unfused`` agree to within ~1 bf16 -ULP — that is the precision floor for this op. -""" - -import argparse - -import torch - -from vllm import _custom_ops as vllm_ops -from vllm.triton_utils import triton - -# The fused dsv4_norm_router_gemm kernel is templated only for DSV4-Pro -# (hidden_size=7168, num_experts=384). Other shapes fall back to the -# unfused path on the Python side (NormGatedLinear), so benchmark only -# the configuration that the fused kernel actually targets. -HIDDEN_SIZE = 7168 -NUM_EXPERTS_CHOICES = (384,) -RMS_EPS = 1e-6 - - -def unfused_norm_router_gemm( - x: torch.Tensor, - norm_weight: torch.Tensor, - gate_weight: torch.Tensor, - eps: float, -) -> tuple[torch.Tensor, torch.Tensor]: - # Call ``_C::rms_norm`` directly (mirroring ``_dsv4_pro_norm_gate``'s - # fallback path) so the benchmarked baseline doesn't inherit any - # Python wrapper overhead or risk falling through to the native - # eager-primitive ``RMSNorm.forward_native`` path. - normed = torch.empty_like(x) - torch.ops._C.rms_norm(normed, x, norm_weight, eps) - logits = vllm_ops.dsv3_router_gemm(normed, gate_weight, torch.float32) - return normed, logits - - -def fused_norm_router_gemm( - x: torch.Tensor, - norm_weight: torch.Tensor, - gate_weight: torch.Tensor, - eps: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return vllm_ops.dsv4_norm_router_gemm(x, norm_weight, gate_weight, eps) - - -def _make_inputs(num_tokens: int, num_experts: int, hidden_size: int, seed: int = 0): - torch.manual_seed(seed) - device = "cuda" - x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) - norm_w = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) - gate_w = torch.randn(num_experts, hidden_size, dtype=torch.bfloat16, device=device) - # Down-scale gate_w so the GEMV output stays in a representable range. - gate_w = gate_w / float(hidden_size) ** 0.5 - norm_w = (norm_w * 0.1) + 1.0 - return x, norm_w, gate_w - - -def calculate_diff( - num_tokens: int, - num_experts: int, - hidden_size: int = HIDDEN_SIZE, - normed_atol: float = 2e-3, - logits_atol: float = 1e-2, - rtol: float = 1e-2, -) -> None: - x, norm_w, gate_w = _make_inputs(num_tokens, num_experts, hidden_size) - - normed_unfused, logits_unfused = unfused_norm_router_gemm( - x.clone(), norm_w, gate_w, RMS_EPS - ) - normed_fused, logits_fused = fused_norm_router_gemm( - x.clone(), norm_w, gate_w, RMS_EPS - ) - - def _max_abs(a, b): - return (a.float() - b.float()).abs().max().item() - - print(f"\n=== M={num_tokens} E={num_experts} H={hidden_size} ===") - print(f"normed_x |fused - unfused| = {_max_abs(normed_fused, normed_unfused):.3e}") - print(f"logits |fused - unfused| = {_max_abs(logits_fused, logits_unfused):.3e}") - - ok_normed = torch.allclose( - normed_fused.float(), - normed_unfused.float(), - atol=normed_atol, - rtol=rtol, - ) - ok_logits = torch.allclose( - logits_fused.float(), - logits_unfused.float(), - atol=logits_atol, - rtol=rtol, - ) - if ok_normed and ok_logits: - print( - f"OK fused vs unfused within " - f"normed_atol={normed_atol:.0e} logits_atol={logits_atol:.0e} " - f"rtol={rtol:.0e}" - ) - else: - print( - f"FAIL normed_ok={ok_normed} logits_ok={ok_logits}; " - f"see max-abs values above" - ) - - -def get_benchmark(): - # Only num_tokens varies (DSV4-Pro hard-codes E=384); single-axis - # sweep yields a clean line plot with M on the x-axis. - num_experts = NUM_EXPERTS_CHOICES[0] - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["num_tokens"], - x_vals=list(range(1, 17)), - line_arg="provider", - line_vals=["unfused", "fused"], - line_names=["unfused (rms+dsv3)", "fused (dsv4)"], - styles=[("green", "-"), ("red", "-")], - ylabel="us", - plot_name=f"norm-router-gemm-E{num_experts}-H{HIDDEN_SIZE}", - args={}, - ) - ) - def benchmark(num_tokens, provider): - x, norm_w, gate_w = _make_inputs(num_tokens, num_experts, HIDDEN_SIZE) - - quantiles = [0.5, 0.2, 0.8] - if provider == "unfused": - fn = lambda: unfused_norm_router_gemm( # noqa: E731 - x, norm_w, gate_w, RMS_EPS - ) - else: - fn = lambda: fused_norm_router_gemm( # noqa: E731 - x, norm_w, gate_w, RMS_EPS - ) - - ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - return benchmark - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--save-path", - type=str, - default="./configs/norm_router_gemm/", - ) - parser.add_argument( - "--skip-bench", - action="store_true", - help="Run only the correctness check, not the perf sweep.", - ) - args = parser.parse_args() - - # Correctness sweep over the full fast-path range M=1..16. - for m in range(1, 17): - for e in NUM_EXPERTS_CHOICES: - calculate_diff(num_tokens=m, num_experts=e, hidden_size=HIDDEN_SIZE) - - if args.skip_bench: - return - - benchmark = get_benchmark() - benchmark.run(print_data=True, save_path=args.save_path) - - -if __name__ == "__main__": - main() diff --git a/csrc/moe/dsv3_router_gemm_bf16_out.cu b/csrc/moe/dsv3_router_gemm_bf16_out.cu index 8c7000ccf352..b11ba991b26c 100644 --- a/csrc/moe/dsv3_router_gemm_bf16_out.cu +++ b/csrc/moe/dsv3_router_gemm_bf16_out.cu @@ -182,7 +182,7 @@ void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = 1; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx( diff --git a/csrc/moe/dsv3_router_gemm_float_out.cu b/csrc/moe/dsv3_router_gemm_float_out.cu index 483eb1e023eb..2756cba0b14f 100644 --- a/csrc/moe/dsv3_router_gemm_float_out.cu +++ b/csrc/moe/dsv3_router_gemm_float_out.cu @@ -182,7 +182,7 @@ void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = 1; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx( diff --git a/csrc/moe/dsv3_router_gemm_utils.h b/csrc/moe/dsv3_router_gemm_utils.h index 13b60d6be6a1..9b533bcabfcc 100644 --- a/csrc/moe/dsv3_router_gemm_utils.h +++ b/csrc/moe/dsv3_router_gemm_utils.h @@ -29,15 +29,3 @@ inline int getSMVersion() { auto* props = at::cuda::getCurrentDeviceProperties(); return props->major * 10 + props->minor; } - -inline bool getEnvEnablePDL() { - static std::once_flag flag; - static bool enablePDL = false; - std::call_once(flag, [&]() { - if (getSMVersion() >= 90) { - const char* env = std::getenv("TRTLLM_ENABLE_PDL"); - enablePDL = env && env[0] == '1' && env[1] == '\0'; - } - }); - return enablePDL; -} diff --git a/csrc/moe/dsv4_norm_router_gemm.h b/csrc/moe/dsv4_norm_router_gemm.h deleted file mode 100644 index 7f66bfcb4aed..000000000000 --- a/csrc/moe/dsv4_norm_router_gemm.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Fused RMSNorm + router GEMV for DeepSeek V4. - * - * Computes in a single kernel: - * normed_x[m,k] = x[m,k] * rsqrt(mean(x[m]^2) + eps) * norm_weight[k] - * router_logits[m,n] = sum_k(normed_x[m,k] * gate_weight[n,k]) - * - * The GEMV body mirrors the algorithm in csrc/moe/dsv3_router_gemm_*.cu - * (warp butterfly + smem cross-warp reduction, fp32 accumulation, PDL on - * SM90+). Blocks 0..kNumTokens-1 each materialize one token's normed_x - * row to global memory using the algebraic identity - * logits[m,n] = rsqrt[m] * sum_k(x[m,k] * nw[k] * gw[n,k]) - * which lets every block produce its column of logits before normed_x - * exists in gmem. - * - * Logits output is fp32 only — DeepSeek V4 router gate is hard-coded to - * fp32 (vllm/model_executor/models/deepseek_v4.py:749). - */ - -#pragma once - -#include -#include - -#include "dsv3_router_gemm_utils.h" - -template -void invokeNormRouterGemm(float* logits, __nv_bfloat16* normed_x, T const* x, - T const* norm_weight, T const* gate_weight, float eps, - cudaStream_t stream); diff --git a/csrc/moe/dsv4_norm_router_gemm_entry.cu b/csrc/moe/dsv4_norm_router_gemm_entry.cu deleted file mode 100644 index 1232248e6177..000000000000 --- a/csrc/moe/dsv4_norm_router_gemm_entry.cu +++ /dev/null @@ -1,130 +0,0 @@ -/* - * TORCH op entry for the fused RMSNorm + router GEMV kernel - * (DeepSeek V4 Pro). This op is DSV4-Pro-specific: the kernel is - * instantiated only for ``num_experts == 384`` and ``hidden_dim == - * 7168``. Other configurations (e.g. DSV4-Flash with H=4096) must - * fall back to the unfused ``rms_norm`` + ``dsv3_router_gemm`` path. - */ - -#include -#include -#include - -#include -#include - -#include "core/registration.h" -#include "dsv4_norm_router_gemm.h" - -namespace { - -// DSV4-Pro hard-coded shape constants. Renamed from the earlier -// ``kKimiK2NumExperts`` to avoid the misleading impression that this -// kernel targets Kimi K2 — 384 happens to match Kimi K2's gate but the -// intent here is DSV4-Pro. -constexpr int kDsv4NumExperts = 384; -constexpr int kDsv4HiddenDim = 7168; - -template -struct LoopUnroller { - static void unroll(int num_tokens, float* logits, __nv_bfloat16* normed_x, - __nv_bfloat16 const* x, __nv_bfloat16 const* norm_weight, - __nv_bfloat16 const* gate_weight, float eps, - cudaStream_t stream) { - if (num_tokens == kBegin) { - invokeNormRouterGemm<__nv_bfloat16, kBegin, kDsv4NumExperts, - kDsv4HiddenDim>(logits, normed_x, x, norm_weight, - gate_weight, eps, stream); - } else { - LoopUnroller::unroll(num_tokens, logits, normed_x, x, - norm_weight, gate_weight, eps, - stream); - } - } -}; - -template -struct LoopUnroller { - static void unroll(int num_tokens, float* logits, __nv_bfloat16* normed_x, - __nv_bfloat16 const* x, __nv_bfloat16 const* norm_weight, - __nv_bfloat16 const* gate_weight, float eps, - cudaStream_t stream) { - if (num_tokens == kEnd) { - invokeNormRouterGemm<__nv_bfloat16, kEnd, kDsv4NumExperts, - kDsv4HiddenDim>(logits, normed_x, x, norm_weight, - gate_weight, eps, stream); - } else { - throw std::invalid_argument( - "Invalid num_tokens, only supports 1 to 16 for " - "dsv4_norm_router_gemm"); - } - } -}; - -} // namespace - -void dsv4_norm_router_gemm(at::Tensor& logits, // [num_tokens, E] fp32 - at::Tensor& normed_x, // [num_tokens, H] bf16 - at::Tensor const& x, // [num_tokens, H] bf16 - at::Tensor const& norm_weight, // [H] bf16 - at::Tensor const& gate_weight, // [E, H] bf16 - double eps) { - TORCH_CHECK(x.dim() == 2 && norm_weight.dim() == 1 && gate_weight.dim() == 2, - "x must be 2D, norm_weight 1D, gate_weight 2D"); - TORCH_CHECK(logits.dim() == 2 && normed_x.dim() == 2, - "logits and normed_x must be 2D"); - - int const num_tokens = x.size(0); - int const hidden_dim = x.size(1); - int const num_experts = gate_weight.size(0); - - TORCH_CHECK(hidden_dim == kDsv4HiddenDim, - "Expected hidden_dim=", kDsv4HiddenDim, - " (DSV4-Pro), but got hidden_dim=", hidden_dim); - TORCH_CHECK(gate_weight.size(1) == hidden_dim, - "gate_weight.shape[1] must equal x.shape[1]"); - TORCH_CHECK(norm_weight.size(0) == hidden_dim, - "norm_weight.shape[0] must equal x.shape[1]"); - TORCH_CHECK(num_experts == kDsv4NumExperts, - "Expected num_experts=", kDsv4NumExperts, - " (DSV4-Pro), but got num_experts=", num_experts); - TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16, - "num_tokens must be in [1, 16] for dsv4_norm_router_gemm"); - - TORCH_CHECK(x.dtype() == at::kBFloat16, "x must be bf16"); - TORCH_CHECK(norm_weight.dtype() == at::kBFloat16, "norm_weight must be bf16"); - TORCH_CHECK(gate_weight.dtype() == at::kBFloat16, "gate_weight must be bf16"); - TORCH_CHECK(normed_x.dtype() == at::kBFloat16, "normed_x must be bf16"); - TORCH_CHECK(logits.dtype() == at::kFloat, - "logits must be float32 (DSV4 router output is hard-coded fp32)"); - - TORCH_CHECK(normed_x.size(0) == num_tokens && normed_x.size(1) == hidden_dim, - "normed_x must be [num_tokens, hidden_dim]"); - TORCH_CHECK(logits.size(0) == num_tokens && logits.size(1) == num_experts, - "logits must be [num_tokens, num_experts]"); - - TORCH_CHECK(x.is_contiguous() && norm_weight.is_contiguous() && - gate_weight.is_contiguous() && normed_x.is_contiguous() && - logits.is_contiguous(), - "all tensors must be contiguous"); - - auto const sm = getSMVersion(); - TORCH_CHECK(sm >= 90 && sm <= 103, - "dsv4_norm_router_gemm requires SM_90 <= CUDA ARCH <= SM_103"); - - cudaStream_t const stream = at::cuda::getCurrentCUDAStream(); - - auto* logits_ptr = reinterpret_cast(logits.mutable_data_ptr()); - auto* nx_ptr = reinterpret_cast<__nv_bfloat16*>(normed_x.mutable_data_ptr()); - auto* x_ptr = reinterpret_cast<__nv_bfloat16 const*>(x.data_ptr()); - auto* nw_ptr = reinterpret_cast<__nv_bfloat16 const*>(norm_weight.data_ptr()); - auto* gw_ptr = reinterpret_cast<__nv_bfloat16 const*>(gate_weight.data_ptr()); - float const eps_f = static_cast(eps); - - LoopUnroller<1, 16>::unroll(num_tokens, logits_ptr, nx_ptr, x_ptr, nw_ptr, - gw_ptr, eps_f, stream); -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("dsv4_norm_router_gemm", &dsv4_norm_router_gemm); -} diff --git a/csrc/moe/dsv4_norm_router_gemm_kernel.cu b/csrc/moe/dsv4_norm_router_gemm_kernel.cu deleted file mode 100644 index dc6e17f19b18..000000000000 --- a/csrc/moe/dsv4_norm_router_gemm_kernel.cu +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Fused RMSNorm + router GEMV for DeepSeek V4 (logits are fp32; bf16 - * output is unsupported because DSV4 hard-codes fp32 logits). See - * dsv4_norm_router_gemm.h for the math. - * - * The GEMV body mirrors csrc/moe/dsv3_router_gemm_float_out.cu (warp - * butterfly reduction + smem cross-warp reduction, fp32 accumulation, - * 128-thread block, PDL on SM90+). RMSNorm is folded into the same - * pass via the identity - * logits[m,n] = rsqrt[m] * sum_k(x[m,k] * nw[k] * gw[n,k]) - * so x is read exactly once per block during the GEMV phase. Blocks - * 0..kNumTokens-1 each materialize one row of normed_x for downstream - * experts / shared_experts to consume. - */ - -#include -#include - -#include -#include - -#include "dsv4_norm_router_gemm.h" - -namespace { - -// Convert 8 bf16 values packed in uint4 into 8 floats. Mirrors the helper -// in dsv3_router_gemm_float_out.cu (kept local so the dsv3 file stays -// untouched). -template -__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, - float* dst) { - __nv_bfloat16* bf16_ptr = - reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); -#pragma unroll - for (int i = 0; i < VPT; i++) { - dst[i] = __bfloat162float(bf16_ptr[i]); - } -} - -template -__global__ __launch_bounds__(128, 1) void norm_router_gemm_kernel( - float* __restrict__ logits, __nv_bfloat16* __restrict__ normed_x, - T const* __restrict__ x, T const* __restrict__ norm_weight, - T const* __restrict__ gate_weight, float eps) { - static_assert(kBlockSize == 128, "kernel assumes blockDim.x == 128"); - static_assert(kHiddenDim % (VPT * kBlockSize) == 0, - "kHiddenDim must be a multiple of VPT * kBlockSize"); - - int const n_idx = blockIdx.x; - int const tid = threadIdx.x; - constexpr int kWarpSize = 32; - constexpr int kNumWarps = kBlockSize / kWarpSize; - constexpr int k_elems_per_iter = VPT * kBlockSize; - constexpr int k_iterations = kHiddenDim / k_elems_per_iter; - - T const* gw_col = gate_weight + n_idx * kHiddenDim; - - // Per-thread accumulators — fp32 throughout, matching dsv3 / layernorm. - float partial[kNumTokens] = {}; - float ss[kNumTokens] = {}; - - // Cross-warp reduction scratch. - __shared__ float sm_partial[kNumTokens][kNumWarps]; - __shared__ float sm_ss[kNumTokens][kNumWarps]; - __shared__ float s_rsqrt[kNumTokens]; - - int k_bases[k_iterations]; -#pragma unroll - for (int ki = 0; ki < k_iterations; ki++) { - k_bases[ki] = ki * k_elems_per_iter + tid * VPT; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - - // ---- Phase 1: single pass over x, accumulate partial GEMV and ss. ---- -#pragma unroll - for (int ki = 0; ki < k_iterations; ki++) { - int const k_base = k_bases[ki]; - - uint4 nw_vec = *reinterpret_cast(norm_weight + k_base); - float nw_f[VPT]; - bf16_uint4_to_float8(nw_vec, nw_f); - - uint4 b_vec = *reinterpret_cast(gw_col + k_base); - float b_f[VPT]; - bf16_uint4_to_float8(b_vec, b_f); - -#pragma unroll - for (int m = 0; m < kNumTokens; m++) { - uint4 a_vec = - *reinterpret_cast(x + m * kHiddenDim + k_base); - float a_f[VPT]; - bf16_uint4_to_float8(a_vec, a_f); - -#pragma unroll - for (int k = 0; k < VPT; k++) { - float a = a_f[k]; - ss[m] += a * a; - partial[m] += a * nw_f[k] * b_f[k]; - } - } - } - - // ---- Phase 2: warp butterfly reduction for both ss[] and partial[]. ---- - int const warpId = tid / kWarpSize; - int const laneId = tid % kWarpSize; - -#pragma unroll - for (int m = 0; m < kNumTokens; m++) { - float p = partial[m]; - float s = ss[m]; - - p += __shfl_xor_sync(0xffffffff, p, 16); - s += __shfl_xor_sync(0xffffffff, s, 16); - p += __shfl_xor_sync(0xffffffff, p, 8); - s += __shfl_xor_sync(0xffffffff, s, 8); - p += __shfl_xor_sync(0xffffffff, p, 4); - s += __shfl_xor_sync(0xffffffff, s, 4); - p += __shfl_xor_sync(0xffffffff, p, 2); - s += __shfl_xor_sync(0xffffffff, s, 2); - p += __shfl_xor_sync(0xffffffff, p, 1); - s += __shfl_xor_sync(0xffffffff, s, 1); - - if (laneId == 0) { - sm_partial[m][warpId] = p; - sm_ss[m][warpId] = s; - } - } - - __syncthreads(); - - // ---- Phase 3: tid 0 finalises the reduction, writes logits, stashes - // rsqrt[m] in smem for phase 4. ---- - if (tid == 0) { -#pragma unroll - for (int m = 0; m < kNumTokens; m++) { - float p_sum = 0.0f; - float s_sum = 0.0f; -#pragma unroll - for (int w = 0; w < kNumWarps; w++) { - p_sum += sm_partial[m][w]; - s_sum += sm_ss[m][w]; - } - // Order matches layernorm_kernels.cu: rsqrtf(variance / H + eps). - // Use division (not multiply-by-reciprocal) to avoid an extra ULP - // mismatch with the reference RMSNorm. - float rs = rsqrtf(s_sum / static_cast(kHiddenDim) + eps); - s_rsqrt[m] = rs; - logits[m * kNumExperts + n_idx] = p_sum * rs; - } - } - - __syncthreads(); - - // ---- Phase 4: spread normed_x writes across blocks 0..kNumTokens-1. - // Each writer block handles exactly one token row, - // avoiding the long tail of block 0 doing all M rows. - // Every block has every token's rsqrt[] in s_rsqrt - // already (computed independently in phase 3), so no - // cross-block synchronization is required. ---- - if (n_idx < kNumTokens) { - int const m_writer = n_idx; - float const rs = s_rsqrt[m_writer]; - __nv_bfloat16 const* x_row = x + m_writer * kHiddenDim; - __nv_bfloat16* normed_row = normed_x + m_writer * kHiddenDim; - -#pragma unroll - for (int ki = 0; ki < k_iterations; ki++) { - int const k_base = k_bases[ki]; - - uint4 nw_vec = *reinterpret_cast(norm_weight + k_base); - float nw_f[VPT]; - bf16_uint4_to_float8(nw_vec, nw_f); - - uint4 a_vec = *reinterpret_cast(x_row + k_base); - float a_f[VPT]; - bf16_uint4_to_float8(a_vec, a_f); - - uint4 normed_vec; - __nv_bfloat16* np = reinterpret_cast<__nv_bfloat16*>(&normed_vec); -#pragma unroll - for (int k = 0; k < VPT; k++) { - np[k] = __float2bfloat16(a_f[k] * rs * nw_f[k]); - } - *reinterpret_cast(normed_row + k_base) = normed_vec; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -} // namespace - -template -void invokeNormRouterGemm(float* logits, __nv_bfloat16* normed_x, T const* x, - T const* norm_weight, T const* gate_weight, float eps, - cudaStream_t stream) { - constexpr int VPT = 16 / sizeof(T); - constexpr int kBlockSize = 128; - - cudaLaunchConfig_t config; - config.gridDim = kNumExperts; - config.blockDim = kBlockSize; - config.dynamicSmemBytes = 0; - config.stream = stream; - - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = 1; - config.numAttrs = 1; - config.attrs = attrs; - - cudaLaunchKernelEx(&config, - norm_router_gemm_kernel, - logits, normed_x, x, norm_weight, gate_weight, eps); -} - -// Template instantiations — DSV4-Pro is the only supported configuration: -// num_experts=384, hidden_dim=7168. Other shapes (e.g. DSV4-Flash with -// hidden_dim=4096) fall back to the unfused path on the Python side. -#define INSTANTIATE(M) \ - template void invokeNormRouterGemm<__nv_bfloat16, M, 384, 7168>( \ - float*, __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, \ - __nv_bfloat16 const*, float, cudaStream_t); - -INSTANTIATE(1) -INSTANTIATE(2) -INSTANTIATE(3) -INSTANTIATE(4) -INSTANTIATE(5) -INSTANTIATE(6) -INSTANTIATE(7) -INSTANTIATE(8) -INSTANTIATE(9) -INSTANTIATE(10) -INSTANTIATE(11) -INSTANTIATE(12) -INSTANTIATE(13) -INSTANTIATE(14) -INSTANTIATE(15) -INSTANTIATE(16) - -#undef INSTANTIATE diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index f8c1f3d92263..ac0e8d59f604 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -75,12 +75,4 @@ void shuffle_rows(const torch::Tensor& input_tensor, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); - -// Fused RMSNorm + router GEMV for DeepSeek V4. Produces both: -// normed_x[m,k] = x[m,k] * rsqrt(mean(x[m]^2) + eps) * norm_weight[k] -// logits[m,n] = sum_k(normed_x[m,k] * gate_weight[n,k]) -// in a single kernel launch. Same dim/dtype constraints as dsv3_router_gemm. -void dsv4_norm_router_gemm(at::Tensor& logits, at::Tensor& normed_x, - at::Tensor const& x, at::Tensor const& norm_weight, - at::Tensor const& gate_weight, double eps); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index b8b91e3c3818..b8145435cb1d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -125,12 +125,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file - - // DeepSeek V4 fused RMSNorm + router GEMV for SM90+ - m.def( - "dsv4_norm_router_gemm(Tensor! logits, Tensor! normed_x, Tensor x, " - "Tensor norm_weight, Tensor gate_weight, float eps) -> ()"); - // conditionally compiled so impl registration is in source file #endif } diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e816d1aaab89..84f944df2bf8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2367,36 +2367,6 @@ def dsv3_router_gemm( return output -def dsv4_norm_router_gemm( - x: torch.Tensor, - norm_weight: torch.Tensor, - gate_weight: torch.Tensor, - eps: float, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fused RMSNorm + router GEMV for DeepSeek V4. - - Returns ``(normed_x, router_logits)`` where - normed_x[m,k] = x[m,k] * rsqrt(mean(x[m]^2) + eps) * norm_weight[k] - router_logits[m,n] = sum_k(normed_x[m,k] * gate_weight[n,k]) - - DSV4-specific constraints (caller must check before dispatching here): - - x, norm_weight, gate_weight all bf16 contiguous - - x.shape == [num_tokens, 7168] with num_tokens in [1, 16] - - gate_weight.shape == [num_experts, 7168] with num_experts in {256, 384} - - SM 9.x or 10.x device - - Logits output is fp32 (hard-coded by DSV4 router). - """ - num_tokens, hidden = x.shape - num_experts = gate_weight.shape[0] - normed_x = torch.empty_like(x) - logits = torch.empty(num_tokens, num_experts, device=x.device, dtype=torch.float32) - torch.ops._moe_C.dsv4_norm_router_gemm( - logits, normed_x, x, norm_weight, gate_weight, float(eps) - ) - return normed_x, logits - - def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index a4d05ef245c7..c242fef2d026 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -170,7 +170,7 @@ def _mhc_pre_tilelang_fake( sinkhorn_repeat: int, n_splits: int = 1, norm_weight: torch.Tensor | None = None, - norm_eps: float = 0.0, + norm_eps: float = 1e-6, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: hc_mult = residual.shape[-2] hidden_size = residual.shape[-1] @@ -238,7 +238,7 @@ def mhc_fused_post_pre_tilelang( n_splits: int = 1, tile_n: int = 1, norm_weight: torch.Tensor | None = None, - norm_eps: float = 0.0, + norm_eps: float = 1e-6, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Run one MHC post block followed by the next MHC pre block. @@ -450,7 +450,7 @@ def _mhc_fused_post_pre_tilelang_fake( n_splits: int = 1, tile_n: int = 1, norm_weight: torch.Tensor | None = None, - norm_eps: float = 0.0, + norm_eps: float = 1e-6, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hc_mult = residual.shape[-2] hidden_size = residual.shape[-1] diff --git a/vllm/model_executor/layers/fused_moe/router/norm_gate_linear.py b/vllm/model_executor/layers/fused_moe/router/norm_gate_linear.py deleted file mode 100644 index 50f1ef7a0efe..000000000000 --- a/vllm/model_executor/layers/fused_moe/router/norm_gate_linear.py +++ /dev/null @@ -1,114 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Fused RMSNorm + GateLinear for DeepSeek V4 MoE routing.""" - -import torch -from torch import nn - -import vllm._custom_ops as ops -from vllm.model_executor.custom_op import PluggableLayer -from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils.torch_utils import direct_register_custom_op - -DSV4_PRO_NUM_EXPERTS = 384 -DSV4_PRO_HIDDEN_SIZE = 7168 -DSV4_PRO_MAX_NUM_TOKENS = 16 - - -def _dsv4_pro_norm_gate( - x: torch.Tensor, - norm_weight: torch.Tensor, - gate_weight: torch.Tensor, - rms_eps: float, -) -> tuple[torch.Tensor, torch.Tensor]: - """Runtime dispatcher: fused ``dsv4_norm_router_gemm`` (M<=16) vs the - unfused ``rms_norm + dsv3_router_gemm`` fallback (M>16). - - """ - if x.shape[0] <= DSV4_PRO_MAX_NUM_TOKENS: - return ops.dsv4_norm_router_gemm(x, norm_weight, gate_weight, rms_eps) - - normed = torch.empty_like(x) - # Call `_C::rms_norm` here to avoid select the path of native rms - torch.ops._C.rms_norm(normed, x, norm_weight, rms_eps) - logits = torch.mm(normed, gate_weight.t(), out_dtype=torch.float32) - return normed, logits - - -def _dsv4_pro_norm_gate_fake( - x: torch.Tensor, - norm_weight: torch.Tensor, - gate_weight: torch.Tensor, - rms_eps: float, -) -> tuple[torch.Tensor, torch.Tensor]: - num_tokens = x.shape[0] - num_experts = gate_weight.shape[0] - return ( - torch.empty_like(x), - torch.empty(num_tokens, num_experts, dtype=torch.float32, device=x.device), - ) - - -direct_register_custom_op( - op_name="dsv4_pro_norm_gate", - op_func=_dsv4_pro_norm_gate, - mutates_args=[], - fake_impl=_dsv4_pro_norm_gate_fake, -) - - -@PluggableLayer.register("norm_gated_linear") -class NormGateLinear(nn.Module): - """RMSNorm + GateLinear, fused on DSV4-Pro only.""" - - def __init__( - self, - hidden_size: int, - num_experts: int, - rms_eps: float = 1e-6, - params_dtype: torch.dtype | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.num_experts = num_experts - self.rms_eps = rms_eps - - self.norm = RMSNorm(hidden_size, eps=rms_eps, dtype=params_dtype) - self.gate = GateLinear( - hidden_size, - num_experts, - bias=False, - out_dtype=torch.float32, # DSV4 router output is fp32 - params_dtype=params_dtype, - prefix=f"{prefix}.gate" if prefix else "gate", - ) - - self.e_score_correction_bias = None - self.tid2eid = None - - self._fused_kernel_supported = ( - hidden_size == DSV4_PRO_HIDDEN_SIZE - and num_experts == DSV4_PRO_NUM_EXPERTS - and self.gate.allow_dsv3_router_gemm # cuda platform - ) - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - if self._fused_kernel_supported: - assert x.shape[1] == DSV4_PRO_HIDDEN_SIZE - assert self.gate.weight.shape == ( - DSV4_PRO_NUM_EXPERTS, - DSV4_PRO_HIDDEN_SIZE, - ) - # This must be wrapped in a custom op because our torch.compile integration - # does not support runtime dispatching on num_tokens. - return torch.ops.vllm.dsv4_pro_norm_gate( - x, self.norm.weight, self.gate.weight, self.rms_eps - ) - - # Non-Pro fallback (e.g. DSV4-Flash with hidden_size=4096): - - normed_x = self.norm(x) - logits, _ = self.gate(normed_x) - return normed_x, logits diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index d69bad8d38de..1540667d1a4b 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -18,13 +18,10 @@ ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( fused_topk_bias, ) -from vllm.model_executor.layers.fused_moe.router.norm_gate_linear import ( - NormGateLinear, -) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -655,23 +652,23 @@ def __init__( "deep_gemm_mega_moe for this checkpoint." ) - # Fused RMSNorm + gate: owns both ffn_norm and the gate matmul. - self.norm_gate = NormGateLinear( - hidden_size=config.hidden_size, - num_experts=config.n_routed_experts, - rms_eps=config.rms_norm_eps, - prefix=f"{prefix}.norm_gate", + self.gate = GateLinear( + input_size=config.hidden_size, + output_size=config.n_routed_experts, + bias=False, + out_dtype=torch.float32, + prefix=f"{prefix}.gate", ) - # Routing-side tensors live on ``norm_gate`` directly (not on the - # inner gate); they are initialized to None in NormGatedLinear and - # populated below depending on the MoE variant. + + self.gate.e_score_correction_bias = None + self.gate.tid2eid = None is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32 if is_hash_moe: # hash MoE doesn't use e_score_correction_bias # Use randint instead of empty to avoid garbage values causing # invalid memory access in dummy mode (--load-format="dummy") - self.norm_gate.tid2eid = nn.Parameter( + self.gate.tid2eid = nn.Parameter( torch.randint( 0, config.n_routed_experts, @@ -681,7 +678,7 @@ def __init__( requires_grad=False, ) elif getattr(config, "topk_method", None) == "noaux_tc": - self.norm_gate.e_score_correction_bias = nn.Parameter( + self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32), requires_grad=False, ) @@ -744,9 +741,10 @@ def _init_fused_moe_experts( self.n_local_experts = config.n_routed_experts // self.tp_size self.experts_start_idx = self.tp_rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts - # We don't pass `gate` into FusedMoE + self.experts = FusedMoE( shared_experts=self.shared_experts, + gate=self.gate, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -756,8 +754,8 @@ def _init_fused_moe_experts( prefix=f"{prefix}.experts", scoring_func=self.scoring_func, routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.norm_gate.e_score_correction_bias, - hash_indices_table=self.norm_gate.tid2eid, + e_score_correction_bias=self.gate.e_score_correction_bias, + hash_indices_table=self.gate.tid2eid, swiglu_limit=self.swiglu_limit, router_logits_dtype=torch.float32, ) @@ -765,40 +763,40 @@ def _init_fused_moe_experts( def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: - if self.norm_gate.tid2eid is not None and input_ids is None: + if self.gate.tid2eid is not None and input_ids is None: raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.") if not self.use_mega_moe: return self._forward_fused_moe(hidden_states, input_ids) org_shape = hidden_states.shape - normed_x, router_logits = self.norm_gate(hidden_states) + router_logits, _ = self.gate(hidden_states) topk_weights, topk_ids = fused_topk_bias( - hidden_states=normed_x, + hidden_states=hidden_states, gating_output=router_logits, scoring_func=self.scoring_func, - e_score_correction_bias=self.norm_gate.e_score_correction_bias.data - if self.norm_gate.e_score_correction_bias is not None + e_score_correction_bias=self.gate.e_score_correction_bias.data + if self.gate.e_score_correction_bias is not None else None, topk=self.n_activated_experts, renormalize=self.renormalize, indices_type=self.hash_indices_dtype, input_tokens=input_ids, - hash_indices_table=self.norm_gate.tid2eid, + hash_indices_table=self.gate.tid2eid, routed_scaling_factor=self.routed_scaling_factor, ) activation_clamp = ( float(self.swiglu_limit) if self.swiglu_limit is not None else None ) final_hidden_states = self.experts( - normed_x, + hidden_states, topk_weights, topk_ids, activation_clamp=activation_clamp, ) if self.shared_experts is not None: - shared_output = self.shared_experts(normed_x) + shared_output = self.shared_experts(hidden_states) final_hidden_states += shared_output return final_hidden_states.view(org_shape) @@ -806,14 +804,21 @@ def forward( def _forward_fused_moe( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: - assert not self.experts.is_internal_router org_shape = hidden_states.shape - normed_x, router_logits = self.norm_gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=normed_x, - router_logits=router_logits, - input_ids=input_ids, - ) + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=hidden_states, + input_ids=input_ids, + ) + else: + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + input_ids=input_ids, + ) return final_hidden_states.view(org_shape) @@ -1017,8 +1022,7 @@ def __init__( self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) - # ``ffn_norm`` is owned by ``self.ffn.norm_gate`` (fused with the - # router gate matmul); see ``NormGatedLinear``. + self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) self.hc_mult = config.hc_mult self.hc_sinkhorn_iters = config.hc_sinkhorn_iters self.hc_eps = config.hc_eps @@ -1148,8 +1152,7 @@ def _forward_cuda( self.hc_post_alpha, self.hc_sinkhorn_iters, ) - # ffn_norm is now folded into self.ffn.norm_gate; ffn() takes - # the pre-norm activation directly. + x = self.ffn_norm(x) x = self.ffn(x, input_ids) return x, residual, post_mix, res_mix @@ -1176,8 +1179,7 @@ def _forward_rocm( x, post, comb = self.hc_pre( x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base ) - # ffn_norm is now folded into self.ffn.norm_gate; ffn() takes - # the pre-norm activation directly. + x = self.ffn_norm(x) x = self.ffn(x, input_ids) x = self.hc_post(x, residual, post, comb) return x, None, None, None @@ -1527,13 +1529,7 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: orig_to_new_suffix={ "head.weight": "lm_head.weight", "embed.weight": "embed_tokens.weight", - # Pre-MoE norm + gate are now owned by ``DeepseekV4MoE.norm_gate`` - # (see NormGatedLinear). - ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", - ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", - ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", - # Hash MoE table also moved off the inner gate. - ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", + ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", }, orig_to_new_substr={ ".attn.compressor.": ".attn.mla_attn.compressor.", diff --git a/vllm/models/deepseek_v4/amd/mtp.py b/vllm/models/deepseek_v4/amd/mtp.py index 071abe2f4a49..bcdd76de4c29 100644 --- a/vllm/models/deepseek_v4/amd/mtp.py +++ b/vllm/models/deepseek_v4/amd/mtp.py @@ -292,11 +292,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ".emb.tok_emb.weight": ".embed_tokens.weight", ".head.weight": ".shared_head.head.weight", ".norm.weight": ".shared_head.norm.weight", - # Pre-MoE norm + gate are now owned by - # ``DeepseekV4MoE.norm_gate`` (see NormGatedLinear). - ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", - ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", - ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", } def _remap_weight_name(name: str) -> str: @@ -444,12 +439,7 @@ def _find_mtp_layer_idx(name: str) -> int: ".shared_experts.w2", ".shared_experts.down_proj" ) if name.endswith(".ffn.gate.bias"): - # ``e_score_correction_bias`` lives on - # ``norm_gate`` directly (not on the inner gate). - name = name.replace( - ".ffn.gate.bias", - ".ffn.norm_gate.e_score_correction_bias", - ) + name = name.replace(".bias", ".e_score_correction_bias") param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 6c4f058cfb1e..44c97715d848 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -1181,8 +1181,7 @@ def _forward_cuda( norm_weight=ffn_norm_weight, norm_eps=ffn_norm_eps, ) - # ffn_norm is fused into mhc_fused_post_pre above; ffn() takes the - # already-normed activation directly. + x = self.ffn(x, input_ids) return x, residual, post_mix, res_mix