[Perf][MoE][ROCm][Kimi-K2.5] Remove a redundant per-forward-pass dtype conversion of the routing bias parameter in DeepSeek-V2/V3 MoE #40341

Closed

xaguilar-amd wants to merge 1 commit into vllm-project:main from xaguilar-amd:remove-bias-elementwise

Conversation

@xaguilar-amd (Contributor) commented Apr 20, 2026

Purpose

Fixes a redundant per-forward-pass dtype conversion of the routing bias parameter
(e_score_correction_bias) in the DeepSeek-V2 / DeepSeek-V3 / Kimi-K2 family of MoE models.

Root cause

In vllm/model_executor/models/deepseek_v2.py, the routing bias is allocated in torch.float32:

self.gate.e_score_correction_bias = nn.Parameter(
    torch.empty(config.n_routed_experts, dtype=torch.float32)
)

However, the router GEMM (GateLinear) produces gating_output in a dtype chosen later via
self.gate.set_out_dtype(...):

  • float32 for monolithic + DeepSeekV3 routing
  • bfloat16 for all other paths (DeepSeekV2, Kimi-K2.5, Mistral Large 3, etc.)

On the ROCm aiter fused-MoE path, vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
then forces the bias to match the gating output dtype every call:

# rocm_aiter_fused_moe.py:152-162
if e_score_correction_bias is not None:
    rocm_aiter_ops.biased_grouped_topk(
        gating_output,
        e_score_correction_bias.to(gating_output.dtype),  # <-- launches fp32 -> bf16 copy every step
        topk_weights, topk_ids,
        ...
    )

Because e_score_correction_bias is a small, constant parameter that never changes after weight
loading, doing this conversion once — at construction time — is sufficient. Today it is instead
launched as a separate kernel between the gate GEMM and the
biased_grouped_topk kernel on every MoE layer, on every decode step.
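
As a quick illustration of why the dtype alignment matters, here is a minimal PyTorch sketch (tensor names and sizes are illustrative, not vLLM's):

import torch

# Stand-ins for the real tensors: a small fp32 routing bias and a bf16 router output.
bias_fp32 = torch.randn(384, dtype=torch.float32)
gating_bf16 = torch.randn(8, 384, dtype=torch.bfloat16)

# Mismatched dtypes: .to() allocates a new tensor and launches a cast kernel on each call.
assert bias_fp32.to(gating_bf16.dtype) is not bias_fp32

# Matched dtypes: .to() returns the same tensor object, with no kernel launch.
bias_bf16 = bias_fp32.to(torch.bfloat16)
assert bias_bf16.to(gating_bf16.dtype) is bias_bf16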

Impact in a Perfetto trace

Profiled on Kimi-K2.5-MXFP4 (ROCm):

Before: [trace screenshot: elementwise_topk]

After: [trace screenshot: no_elementwise_topk]

Why fix it in DeepseekV2MoE rather than in rocm_aiter_fused_moe.py

The .to(gating_output.dtype) in rocm_aiter_fused_moe.py is required for correctness on its own
— the aiter kernel reinterprets correction_bias as the same dtype as gating_output
(see VLLM_DISPATCH_FLOATING_TYPES(gating_output.scalar_type(), ...) in the aiter
biased_grouped_topk launcher). Removing it would break callers that pass a mismatched-dtype bias.

The right fix is to make sure the bias already has the correct dtype before the kernel is called,
by aligning it with gate.out_dtype at construction time. When the dtypes already match,
tensor.to(same_dtype) is a cheap no-op in PyTorch (returns the same tensor, no kernel launch),
so this PR is compatible with both CUDA and ROCm paths.


Changes

Single-file change in vllm/model_executor/models/deepseek_v2.py, inside
DeepseekV2MoE.__init__, immediately after the existing self.gate.set_out_dtype(...) call:

self.gate.set_out_dtype(
    torch.float32
    if self.experts.quant_method.is_monolithic
    and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
    else torch.bfloat16
)

# Align e_score_correction_bias dtype with the gate's output dtype so
# downstream routing kernels (e.g. aiter biased_grouped_topk) don't
# have to cast this constant parameter on every forward pass. We
# mutate `.data` in place to preserve the Parameter identity already
# captured by `self.experts` / the router. The weight loader uses
# `param.data.copy_(loaded_weight)`, which converts the loaded fp32
# checkpoint tensor into this dtype automatically at load time.
if self.gate.e_score_correction_bias is not None:
    target_dtype = self.gate.out_dtype
    if self.gate.e_score_correction_bias.dtype != target_dtype:
        self.gate.e_score_correction_bias.data = (
            self.gate.e_score_correction_bias.data.to(target_dtype)
        )

Why do the cast in-place on .data instead of replacing the nn.Parameter

self.gate.e_score_correction_bias is passed by reference into SharedFusedMoE
(line ~334 of the same file), which stores it on the router
(GroupedTopKRouter.e_score_correction_bias, FusedMoE.e_score_correction_bias, etc.).
Reassigning self.gate.e_score_correction_bias = new_parameter would leave those consumers holding
the old fp32 Parameter. Mutating .data in place keeps the nn.Parameter object identity, so all
existing references see the new dtype.
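
A minimal sketch of the identity argument (standalone toy code, not vLLM's classes):

import torch
import torch.nn as nn

bias = nn.Parameter(torch.zeros(4, dtype=torch.float32))
consumer_ref = bias  # stand-in for the reference held by SharedFusedMoE / the router

# Reassigning would create a new Parameter and strand consumer_ref on the old fp32 one.
# Mutating .data keeps the Parameter object, so every holder sees the new dtype:
bias.data = bias.data.to(torch.bfloat16)
assert consumer_ref is bias
assert consumer_ref.dtype == torch.bfloat16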

Why not just change the dtype=torch.float32 on the allocation line

The correct target depends on the quant-method choice made later in __init__:

  • monolithic + DeepSeekV3 routing → gate outputs float32 → bias must be float32
  • everything else (including Kimi-K2.5 MXFP4 + aiter) → gate outputs bfloat16 → bias must be bfloat16

At the point the bias is allocated, self.experts does not exist yet, so is_monolithic and
routing_method_type are not yet known. Aligning after set_out_dtype is called is the simplest
correct option and keeps the fix local to one file.

Weight-loading compatibility

default_weight_loader in vllm/model_executor/model_loader/weight_utils.py performs
param.data.copy_(loaded_weight), which does automatic dtype conversion from the checkpoint's
float32 tensor into whatever dtype the parameter now has. No checkpoint compatibility impact.
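
A minimal sketch of that copy_() behavior:

import torch

param_data = torch.empty(4, dtype=torch.bfloat16)     # parameter storage after this PR
loaded_weight = torch.randn(4, dtype=torch.float32)   # fp32 tensor from the checkpoint
param_data.copy_(loaded_weight)  # copy_ casts fp32 -> bf16 during the copy
assert param_data.dtype == torch.bfloat16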

Scope

This PR intentionally only touches deepseek_v2.py. Several other model files have a similar
fp32 allocation of e_score_correction_bias
(e.g. glm4_moe.py, dots1.py, openpangu.py, ernie45_moe.py, lfm2_moe.py, nemotron_h.py,
minimax_m2.py, sarvam.py, exaone_moe.py, ernie45_vl_moe.py, AXK1.py, mimo_v2_flash.py,
param2moe.py, longcat_flash.py), but they are outside the DeepSeek-V2/V3/Kimi-K2 path actually
exercised by this change.


Test Result

Accuracy

Accuracy is preserved (no drift, identical to within eval-level noise).

Tasks   Version  Filter            n-shot  Metric       Value   Stderr
gsm8k   3        flexible-extract  5       exact_match  0.9265  ± 0.0072
                 strict-match      5       exact_match  0.9265  ± 0.0072

Performance

  • Removes ~5 µs of serialized work on the routing critical path per MoE layer per decode step.

Risks / Compatibility

  • Numerical: the aiter biased_grouped_topk kernel already operates on a gating_output.dtype
    view of the bias — we simply move the dtype conversion from "every iteration" to "once at load
    time". The weight loader does the same fp32 → target_dtype conversion it would otherwise do,
    once.
  • Monolithic + DeepSeekV3 path: gate.out_dtype == torch.float32, so the new branch is a no-op
    (bias stays fp32). Behavior identical to today.
  • Non-ROCm / native grouped_topk path
    (vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py:125): the bias is used in
    scores + e_score_correction_bias.unsqueeze(0). When gate.out_dtype == bf16, both sides become
    bf16, which is consistent with the dtype of scores (derived from gating_output). No dtype
    promotion is introduced (see the minimal dtype check after this list).
  • CUDA fused ops.grouped_topk path (guarded by current_platform.is_cuda() +
    VLLM_USE_FUSED_MOE_GROUPED_TOPK): unchanged — this PR only flips which dtype the bias tensor
    carries; the fused kernel receives the bias tensor directly as it does today.
  • Checkpoint compatibility: unchanged — HF checkpoints still ship the bias as fp32;
    default_weight_loader's param.data.copy_() converts during load.
  • torch.compile / CUDAGraph capture: unaffected — the dtype is fixed before the first
    forward pass.
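
A minimal check of the no-promotion claim from the native grouped_topk bullet above (shapes and names illustrative):

import torch

scores = torch.randn(2, 8, dtype=torch.bfloat16)  # derived from gating_output
bias = torch.randn(8, dtype=torch.bfloat16)       # e_score_correction_bias after this PR
out = scores + bias.unsqueeze(0)                  # bf16 + bf16 stays bf16
assert out.dtype == torch.bfloat16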

…ias parameter

Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
mergify bot added the deepseek (Related to DeepSeek models) and rocm (Related to AMD ROCm) labels Apr 20, 2026
github-project-automation bot moved this to Todo in AMD Apr 20, 2026

gemini-code-assist bot left a comment


Code Review

This pull request updates the DeepseekV2 model implementation to align the e_score_correction_bias data type with the gate's output data type during initialization. This change optimizes performance by avoiding redundant type casting in the routing kernels during the forward pass. I have no feedback to provide.

xaguilar-amd marked this pull request as ready for review April 20, 2026 09:40

claude bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@Rohan138 (Contributor)

Is this a duplicate of #39999? (already merged)

@xaguilar-amd (Contributor, Author)

yes, I missed that one, sorry! Closing this now. Thanks!

github-project-automation bot moved this from Todo to Done in AMD Apr 30, 2026
@Rohan138 (Contributor)

Turns out I was wrong ... this was broken in main, see #41405 which will fix it again
