[ROCm] Cast score correction bias tensor during model construction for DeepSeek/Kimi-K2#39999

Merged
tjtanaa merged 8 commits into vllm-project:main from heachary:heachary/kk2/fix-score-correction-bias-typecast
Apr 24, 2026

Conversation

Contributor

@heachary heachary commented Apr 16, 2026

Purpose

The MoE score correction bias tensor was being cast to the gate output dtype on every forward pass. The dtype this tensor needs to be cast to is known at model construction time and never changes after that, so the repeated cast is redundant work that launches an extra GPU kernel per MoE layer per forward call.

This PR moves the cast to model construction, thereby eliminating the per-forward-pass overhead.
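
For illustration, a minimal sketch of the pattern being changed; the class and parameter names are simplified stand-ins, not the actual vLLM code:

```python
import torch
import torch.nn as nn


class RouterBefore(nn.Module):
    """Per-forward cast: an extra elementwise kernel on every call."""

    def __init__(self, num_experts: int):
        super().__init__()
        # The checkpoint weight for the bias is typically float32.
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros(num_experts, dtype=torch.float32), requires_grad=False
        )

    def forward(self, gating_output: torch.Tensor) -> torch.Tensor:
        # Redundant: the target dtype never changes between calls.
        bias = self.e_score_correction_bias.to(gating_output.dtype)
        return gating_output + bias


class RouterAfter(nn.Module):
    """Init-time cast: the gate output dtype is known at construction."""

    def __init__(self, num_experts: int, out_dtype: torch.dtype):
        super().__init__()
        bias = torch.zeros(num_experts, dtype=torch.float32)
        # Cast once, when the model is built.
        self.e_score_correction_bias = nn.Parameter(
            bias.to(out_dtype), requires_grad=False
        )

    def forward(self, gating_output: torch.Tensor) -> torch.Tensor:
        # Dtypes match by construction; no cast kernel is launched here.
        return gating_output + self.e_score_correction_bias
```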

Summary

Before

[trace screenshot]

After

[trace screenshot]

The traces above show the elementwise kernel responsible for this typecast running just before the grouped-topk operation on every forward pass. With the following changes, the typecast is shifted to model construction, so it no longer runs on every forward pass:

- vllm/model_executor/models/deepseek_v2.py: Pre-cast e_score_correction_bias to match gate.out_dtype during DeepseekV2MoE construction. Since all downstream consumers (FusedMoE, router) share the same nn.Parameter object, this single mutation propagates everywhere (see the sketch after this list).
- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py: Replace the runtime .to() cast with an assert that the bias dtype already matches the gating output dtype, catching any future regression where the init-time cast is missed.
- vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py: Same change: replace the runtime .to() cast with a matching assert.
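
To ground the claim in the first bullet, a small standalone demonstration (plain PyTorch, not vLLM code) that re-pointing a shared nn.Parameter's .data is visible to every holder of the object:

```python
import torch
import torch.nn as nn

bias = nn.Parameter(torch.randn(8, dtype=torch.float32), requires_grad=False)

# Stand-ins for FusedMoE and the router, each holding the same Parameter.
fused_moe_ref = bias
router_ref = bias

# Re-point the Parameter's underlying tensor, as the PR does at init time.
bias.data = bias.data.to(torch.bfloat16)

# Every reference observes the new dtype, since all share one object.
assert fused_moe_ref.dtype == torch.bfloat16
assert router_ref.dtype == torch.bfloat16
```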

Test Result

Accuracy

| Benchmark | Metric | Score | Threshold | Status |
|---|---|---|---|---|
| GSM8K (5-shot, 250 samples) | exact_match (flexible-extract) | 0.936 | 0.90 | ✅ PASS |
| GSM8K (5-shot, 250 samples) | exact_match (strict-match) | 0.936 | 0.90 | ✅ PASS |

Performance

| Config | Concurrency | Baseline | Type-cast fix | Improvement (fix / baseline) |
|---|---|---|---|---|
| 1k1k | 4 | 190 | 194 | 1.021053 |
| 1k1k | 8 | 321 | 323 | 1.006231 |
| 1k1k | 16 | 521 | 528 | 1.013436 |
| 1k1k | 32 | 783 | 799 | 1.020434 |
| 1k1k | 64 | 1173 | 1180 | 1.005968 |
| 1k8k | 4 | 108 | 110 | 1.018519 |
| 1k8k | 8 | 188 | 191 | 1.015957 |
| 1k8k | 16 | 307 | 309 | 1.006515 |
| 1k8k | 32 | 485 | 491 | 1.012371 |
| 1k8k | 64 | 730 | 734 | 1.005479 |
| 8k1k | 4 | 799 | 816 | 1.021277 |
| 8k1k | 8 | 1348 | 1365 | 1.012611 |
| 8k1k | 16 | 2051 | 2064 | 1.006338 |
| 8k1k | 32 | 2917 | 2949 | 1.010970 |
| 8k1k | 64 | 4038 | 4012 | 0.993561 |
| Geomean | | | | 1.011355 |
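
As a sanity check, the geomean row is the geometric mean of the fifteen per-row improvement ratios:

```python
import math

ratios = [
    1.021053, 1.006231, 1.013436, 1.020434, 1.005968,  # 1k1k
    1.018519, 1.015957, 1.006515, 1.012371, 1.005479,  # 1k8k
    1.021277, 1.012611, 1.006338, 1.010970, 0.993561,  # 8k1k
]
geomean = math.prod(ratios) ** (1 / len(ratios))
print(f"{geomean:.6f}")  # 1.011355
```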

Test Plan

  • Accuracy test
  • Performance benchmark

Signed-off-by: Hemanth Acharya <heachary@amd.com>
@mergify mergify Bot added the deepseek (Related to DeepSeek models) and rocm (Related to AMD ROCm) labels Apr 16, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 16, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request improves the efficiency of the DeepSeek-V2 model by pre-casting the e_score_correction_bias during the initialization phase, which avoids repeated type conversions during each forward pass. Additionally, it adds assertions to the ROCm Aiter fused MoE layers to verify that the bias and gating output types match. Feedback highlights a potential issue where direct mutation of the parameter's data and the use of a static type attribute could cause the new assertions to fail if the model is cast to a different precision after initialization.

Comment thread vllm/model_executor/models/deepseek_v2.py
Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary force-pushed the heachary/kk2/fix-score-correction-bias-typecast branch from 0330c2b to 34dd717 on April 17, 2026 11:25
Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary marked this pull request as ready for review April 18, 2026 10:00

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Collaborator

@tjtanaa tjtanaa left a comment


LGTM.

@tjtanaa tjtanaa added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 19, 2026
@tjtanaa tjtanaa enabled auto-merge (squash) April 19, 2026 14:19
Comment thread vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py Outdated
Comment thread vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py Outdated
Comment thread vllm/model_executor/models/deepseek_v2.py Outdated
Signed-off-by: Hemanth Acharya <heachary@amd.com>
auto-merge was automatically disabled April 21, 2026 08:27

Head branch was pushed to by a user without write access

Signed-off-by: Hemanth Acharya <heachary@amd.com>
Collaborator

tjtanaa commented Apr 21, 2026

@bnellnm could you take a final look?

Comment on lines +354 to +367

```python
# Pre-cast the bias to match the gate output dtype so the
# conversion is not repeated on every forward pass. All
# downstream references (FusedMoE, router) share the same
# nn.Parameter object, so mutating .data propagates everywhere.
# Weight loading uses copy_(), which handles the dtype conversion.
# Only needed on ROCm where the aiter biased_grouped_topk kernel
# requires the bias dtype to match the gating output dtype.
if (
    self.is_rocm_aiter_moe_enabled
    and self.gate.e_score_correction_bias is not None
):
    self.gate.e_score_correction_bias.data = (
        self.gate.e_score_correction_bias.data.to(self.gate.out_dtype)
    )
```
Collaborator

@bnellnm bnellnm Apr 21, 2026


I think this block of code could live in fused_moe/layer.py (with any additional appropriate checks, e.g. routing type)

Contributor Author

@heachary heachary Apr 22, 2026


I already mentioned in my previous comment why that's a harder change that I decided to skip. Let me elaborate with some details here:

Moving the bias pre-cast (lines 354-367) into FusedMoE.__init__() isn't standalone — it depends on gate.set_out_dtype(), which is called just above it, and that call relies on self.experts.quant_method.is_monolithic and self.experts.routing_method_type — both only available after FusedMoE.__init__() completes. So both blocks (set_out_dtype() and the new bias dtype cast) would need to move together to the end of FusedMoE.__init__().

The concern is that this becomes more invasive: every model passing gate= to FusedMoE — including qwen3_moe, qwen3_next, step3p5, and AXK1 — would now have set_out_dtype called automatically in FusedMoE.__init__(), which changes their gate output dtype behavior even though they don't currently call set_out_dtype at all.

If this is not a big concern, I would like to leave this section as is to minimise the impact.
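
To make the ordering constraint concrete, a hypothetical, runnable sketch; the class and attribute names echo the comment above but are stubs, not the actual vLLM definitions:

```python
import torch
import torch.nn as nn


class GateLike(nn.Module):
    """Stub gate holding the score correction bias (hypothetical)."""

    def __init__(self, num_experts: int):
        super().__init__()
        self.out_dtype = None
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros(num_experts, dtype=torch.float32), requires_grad=False
        )

    def set_out_dtype(self, dtype: torch.dtype) -> None:
        self.out_dtype = dtype


class FusedMoELike:
    """Stub illustrating why the cast must trail the rest of __init__."""

    def __init__(self, gate: GateLike):
        # These attributes only exist once __init__ has run; in vLLM they
        # come from the quant method and the routing configuration.
        self.is_monolithic = False
        self.routing_method_type = "biased_grouped_topk"

        # set_out_dtype() depends on the attributes above, so it could only
        # run here, at the end of __init__, and would then fire for every
        # model that passes gate=, not just DeepSeek/Kimi-K2.
        gate.set_out_dtype(torch.bfloat16)

        # The bias pre-cast depends on gate.out_dtype, so it must follow.
        if gate.e_score_correction_bias is not None:
            gate.e_score_correction_bias.data = (
                gate.e_score_correction_bias.data.to(gate.out_dtype)
            )


gate = GateLike(num_experts=8)
FusedMoELike(gate)
assert gate.e_score_correction_bias.dtype == torch.bfloat16
```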

Collaborator

@bnellnm bnellnm left a comment


LGTM but I think the initial casting code could probably live in layer.py since it seems like it is generally applicable to ROCm MoE and particular routing methods.

@heachary
Contributor Author

@tjtanaa there's a failing unit test, but from the looks of it I think it's unrelated to the changes in this PR. Let me know if I should investigate further.

@tjtanaa tjtanaa merged commit fa4b705 into vllm-project:main Apr 24, 2026
72 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Apr 24, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
…r DeepSeek/Kimi-K2 (vllm-project#39999)

Signed-off-by: Hemanth Acharya <heachary@amd.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
…r DeepSeek/Kimi-K2 (vllm-project#39999)

Signed-off-by: Hemanth Acharya <heachary@amd.com>
Signed-off-by: Adrian <info@zzit.ch>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…r DeepSeek/Kimi-K2 (vllm-project#39999)

Signed-off-by: Hemanth Acharya <heachary@amd.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
