Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,11 @@ def rocm_aiter_grouped_topk(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

if e_score_correction_bias is not None:
if e_score_correction_bias.dtype != gating_output.dtype:
e_score_correction_bias = e_score_correction_bias.to(gating_output.dtype)
Comment thread
heachary marked this conversation as resolved.
Outdated
rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
e_score_correction_bias,
topk_weights,
topk_ids,
num_expert_group,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,13 @@ def fused_topk_bias(
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
if e_score_correction_bias.dtype != gating_output.dtype:
e_score_correction_bias = e_score_correction_bias.to(
gating_output.dtype
)
Comment thread
heachary marked this conversation as resolved.
Outdated
rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
e_score_correction_bias,
topk_weights,
topk_ids,
num_expert_group=num_expert_group,
Expand Down
10 changes: 10 additions & 0 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,16 @@ def __init__(
else torch.bfloat16
)

# Pre-cast the bias to match the gate output dtype so the
# conversion is not repeated on every forward pass. All
# downstream references (FusedMoE, router) share the same
# nn.Parameter object, so mutating .data propagates everywhere.
# Weight loading uses copy_(), which handles the dtype conversion.
if self.gate.e_score_correction_bias is not None:
self.gate.e_score_correction_bias.data = (
self.gate.e_score_correction_bias.data.to(self.gate.out_dtype)
)
Comment thread
heachary marked this conversation as resolved.
Comment thread
heachary marked this conversation as resolved.
Outdated
Comment on lines +351 to +364

@bnellnm bnellnm Apr 21, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this block of code could live in fused_moe/layer.py (with any additional appropriate checks, e.g. routing type)

@heachary heachary Apr 22, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned already in my previous comment why thats a harder change that i decided to skip. Let me elaborate with some details here:

Moving the bias pre-cast (lines 354-367) into FusedMoE.init() isn't standalone — it depends on gate.set_out_dtype() which is called just above it, and that call relies on self.experts.quant_method.is_monolithic and self.experts.routing_method_type — both only available after FusedMoE.init() completes. So both blocks (set_out_dtype() and the new bias dtype cast) would need to move together to the end of FusedMoE.init().

The concern is that this becomes more invasive: every model passing gate= to FusedMoE — including qwen3_moe, qwen3_next, step3p5, and AXK1 — would now have set_out_dtype called automatically in FusedMoE.init(), which changes their gate output dtype behavior even though they don't currently call set_out_dtype at all.

If this is not a big concern, I would like to leave this section as is to minimise the impact.


def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
Expand Down
Loading