[Perf][MoE][ROCm][Kimi-K2.5] Remove a redundant per-forward-pass dtype conversion of the routing bias parameter in DeepSeek-V2/V3 MoE#40341
Closed
xaguilar-amd wants to merge 1 commit into vllm-project:main
Conversation
…ias parameter Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
Contributor
Code Review
This pull request updates the DeepseekV2 model implementation to align the e_score_correction_bias data type with the gate's output data type during initialization. This change optimizes performance by avoiding redundant type casting in the routing kernels during the forward pass. I have no feedback to provide.
Contributor
Is this a duplicate of #39999? (already merged)
Contributor
Author
Yes, I missed that one, sorry! Closing this now. Thanks!
Contributor
Turns out I was wrong ... this was broken in main, see #41405 which will fix it again |
Purpose
Fixes a redundant per-forward-pass dtype conversion of the routing bias parameter (`e_score_correction_bias`) in the DeepSeek-V2 / DeepSeek-V3 / Kimi-K2 family of MoE models.

Root cause
In `vllm/model_executor/models/deepseek_v2.py`, the routing bias is allocated in `torch.float32`. However, the router GEMM (`GateLinear`) produces `gating_output` in a dtype chosen later via `self.gate.set_out_dtype(...)`:

- `float32` for monolithic + DeepSeekV3 routing
- `bfloat16` for all other paths (DeepSeekV2, Kimi-K2.5, Mistral Large 3, etc.)

On the ROCm aiter fused-MoE path, `vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py` then forces the bias to match the gating output dtype on every call; a sketch of the two sides of the mismatch follows below.
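A minimal sketch of the mismatch, simplified from the actual vLLM code (the expert count and the wrapper function `route_with_bias` below are illustrative, not the names in the file):

```python
import torch
import torch.nn as nn

# Allocation in DeepseekV2MoE.__init__ (simplified): the routing bias is always fp32.
n_routed_experts = 256  # illustrative value, not taken from the model config
e_score_correction_bias = nn.Parameter(
    torch.empty(n_routed_experts, dtype=torch.float32)
)

def route_with_bias(gating_output: torch.Tensor, correction_bias: torch.Tensor):
    """Stand-in for the bias handling on the ROCm aiter fused-MoE path."""
    # When correction_bias is fp32 and gating_output is bf16 (the Kimi-K2.5 path),
    # this .to() launches a small conversion kernel on every MoE layer of every step.
    correction_bias = correction_bias.to(gating_output.dtype)
    # ... aiter biased_grouped_topk launch using correction_bias ...
    return correction_bias
```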
Because `e_score_correction_bias` is a small, constant parameter that never changes after weight loading, doing this conversion once — at construction time — is sufficient. Today it is instead launched as a separate kernel between the gate GEMM and the `biased_grouped_topk` kernel on every MoE layer, on every decode step.

Impact in a Perfetto trace
Profiled on Kimi-K2.5-MXFP4 (ROCm):
Before: (Perfetto trace screenshot)
After: (Perfetto trace screenshot)
Why fix it in `DeepseekV2MoE` rather than in `rocm_aiter_fused_moe.py`

The `.to(gating_output.dtype)` in `rocm_aiter_fused_moe.py` is required for correctness on its own — the aiter kernel reinterprets `correction_bias` as the same dtype as `gating_output` (see `VLLM_DISPATCH_FLOATING_TYPES(gating_output.scalar_type(), ...)` in the aiter `biased_grouped_topk` launcher). Removing it would break callers that pass a mismatched-dtype bias.

The right fix is to make sure the bias already has the correct dtype before the kernel is called, by aligning it with `gate.out_dtype` at construction time. When the dtypes already match, `tensor.to(same_dtype)` is a cheap no-op in PyTorch (it returns the same tensor with no kernel launch; see the one-line check below), so this PR is compatible with both CUDA and ROCm paths.
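A quick standalone check of that no-op behavior (plain PyTorch, independent of vLLM):

```python
import torch

x = torch.randn(4, dtype=torch.bfloat16)
# .to() with the tensor's existing dtype returns the same tensor object,
# so no conversion kernel is launched.
assert x.to(torch.bfloat16) is x
```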
Changes
Single-file change in `vllm/model_executor/models/deepseek_v2.py`, inside `DeepseekV2MoE.__init__`, immediately after the existing `self.gate.set_out_dtype(...)` call; a sketch of the added logic follows below.
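Roughly, the change has the following shape (a sketch, not the exact diff; `out_dtype` is the attribute the PR description refers to as `gate.out_dtype`):

```python
# Inside DeepseekV2MoE.__init__, after self.gate.set_out_dtype(...):
# align the routing bias with the dtype the router GEMM will actually produce.
bias = self.gate.e_score_correction_bias
if bias is not None and bias.dtype != self.gate.out_dtype:
    # Mutate .data in place so every module that already holds a reference to
    # this nn.Parameter (SharedFusedMoE, the router, ...) sees the new dtype.
    bias.data = bias.data.to(self.gate.out_dtype)
```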
Why do the cast in-place on `.data` instead of replacing the `nn.Parameter`

`self.gate.e_score_correction_bias` is passed by reference into `SharedFusedMoE` (line ~334 of the same file), which stores it on the router (`GroupedTopKRouter.e_score_correction_bias`, `FusedMoE.e_score_correction_bias`, etc.). Reassigning `self.gate.e_score_correction_bias = new_parameter` would leave those consumers holding the old fp32 Parameter. Mutating `.data` in place keeps the `nn.Parameter` object identity, so all existing references see the new dtype.
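A minimal standalone illustration of that aliasing behavior (plain PyTorch, no vLLM code):

```python
import torch
import torch.nn as nn

gate_bias = nn.Parameter(torch.zeros(4, dtype=torch.float32))
router_bias = gate_bias  # a second module keeps a reference to the same Parameter

# Reassigning the attribute would create a new object and leave router_bias fp32.
# Mutating .data keeps the Parameter identity, so both references see bf16.
gate_bias.data = gate_bias.data.to(torch.bfloat16)
assert router_bias is gate_bias
assert router_bias.dtype == torch.bfloat16
```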
Why not just change the `dtype=torch.float32` on the allocation line

The correct target depends on the quant-method choice made later in `__init__`:

- `float32` gate output → bias must be `float32`
- `bfloat16` gate output → bias must be `bfloat16`

At the point the bias is allocated, `self.experts` does not exist yet, so `is_monolithic` and `routing_method_type` are not yet known. Aligning after `set_out_dtype` is called is the simplest correct option and keeps the fix local to one file.
Weight-loading compatibility
`default_weight_loader` in `vllm/model_executor/model_loader/weight_utils.py` performs `param.data.copy_(loaded_weight)`, which does automatic dtype conversion from the checkpoint's `float32` tensor into whatever dtype the parameter now has. No checkpoint compatibility impact; the snippet below illustrates the behavior.
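A quick standalone check of that `copy_` behavior (plain PyTorch, the tensors are stand-ins for the real parameter and checkpoint weight):

```python
import torch
import torch.nn as nn

# The parameter now lives in bf16, but the checkpoint tensor is still fp32.
param = nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
loaded_weight = torch.randn(4, dtype=torch.float32)  # stand-in for the checkpoint tensor

# copy_() converts dtypes implicitly, so loading keeps working unchanged.
param.data.copy_(loaded_weight)
assert param.dtype == torch.bfloat16
```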
Scope

This PR intentionally only touches `deepseek_v2.py`. Several other model files have a similar fp32 allocation of `e_score_correction_bias` (e.g. `glm4_moe.py`, `dots1.py`, `openpangu.py`, `ernie45_moe.py`, `lfm2_moe.py`, `nemotron_h.py`, `minimax_m2.py`, `sarvam.py`, `exaone_moe.py`, `ernie45_vl_moe.py`, `AXK1.py`, `mimo_v2_flash.py`, `param2moe.py`, `longcat_flash.py`), but they are outside the DeepSeek-V2/V3/Kimi-K2 path actually exercised by this change.
Test Result
Accuracy
Accuracy is preserved (no drift, identical to within eval-level noise).
Performance
Risks / Compatibility
- The `biased_grouped_topk` kernel already operates on a `gating_output.dtype` view of the bias — we simply move the dtype conversion from "every iteration" to "once at load time". The weight loader does the same fp32 → target_dtype conversion it would otherwise do, once.
- Paths where `gate.out_dtype == torch.float32`: the new branch is a no-op (the bias stays fp32). Behavior identical to today.
- `grouped_topk` path (`vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py:125`): the bias is used in `scores + e_score_correction_bias.unsqueeze(0)`. When `gate.out_dtype == bf16`, both sides become `bf16`, which is consistent with the dtype of `scores` (derived from `gating_output`). No dtype promotion is introduced.
- `ops.grouped_topk` path (guarded by `current_platform.is_cuda()` + `VLLM_USE_FUSED_MOE_GROUPED_TOPK`): unchanged — this PR only flips which dtype the bias tensor carries; the fused kernel receives the bias tensor directly as it does today.
- Checkpoint loading: `default_weight_loader`'s `param.data.copy_()` converts during load.
- `torch.compile` / CUDA Graph capture: unaffected — the dtype is fixed before the first forward pass.