diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cd28fb0192f3..b3359333c9fc 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -351,6 +351,20 @@ def __init__( else torch.bfloat16 ) + # Align e_score_correction_bias dtype with the gate's output dtype so + # downstream routing kernels (e.g. aiter biased_grouped_topk) don't + # have to cast this constant parameter on every forward pass. We + # mutate `.data` in place to preserve the Parameter identity already + # captured by `self.experts` / the router. The weight loader uses + # `param.data.copy_(loaded_weight)`, which converts the loaded fp32 + # checkpoint tensor into this dtype automatically at load time. + if self.gate.e_score_correction_bias is not None: + target_dtype = self.gate.out_dtype + if self.gate.e_score_correction_bias.dtype != target_dtype: + self.gate.e_score_correction_bias.data = ( + self.gate.e_score_correction_bias.data.to(target_dtype) + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim)