diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 26a1903fa7b9..86881376a106 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -299,6 +299,15 @@ def __init__(
         self.is_fusion_moe_shared_experts_enabled = (
             rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
         )
+        if (
+            self.is_rocm_aiter_moe_enabled
+            and self.gate.e_score_correction_bias is not None
+        ):
+            # AITER biased_grouped_topk requires the correction bias dtype to
+            # match the router logits. Keep DeepSeek's correction bias in fp32
+            # by requesting fp32 router logits for this routing path.
+            self.gate.set_out_dtype(torch.float32)
+
         if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
             self.shared_experts = None
         else:
@@ -338,22 +347,15 @@ def __init__(
             n_shared_experts=config.n_shared_experts
             if self.is_fusion_moe_shared_experts_enabled
             else None,
+            router_logits_dtype=self.gate.out_dtype,
         )
 
-        # Pre-cast the bias to match the gate output dtype so the
-        # conversion is not repeated on every forward pass. All
-        # downstream references (FusedMoE, router) share the same
-        # nn.Parameter object, so mutating .data propagates everywhere.
-        # Weight loading uses copy_(), which handles the dtype conversion.
-        # Only needed on ROCm where the aiter biased_grouped_topk kernel
-        # requires the bias dtype to match the gating output dtype.
         if (
             self.is_rocm_aiter_moe_enabled
             and self.gate.e_score_correction_bias is not None
         ):
-            gate_out_dtype = self.gate.out_dtype or self.gate.weight.dtype
             self.gate.e_score_correction_bias.data = (
-                self.gate.e_score_correction_bias.data.to(gate_out_dtype)
+                self.gate.e_score_correction_bias.data.to(self.gate.out_dtype)
             )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 2106226118ef..121194b14460 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -396,7 +396,7 @@ class AiterMLAHelper:
     """
 
     _AITER_MIN_MLA_HEADS: Final = 16
-    _AITER_UNSUPPORTED_HEADS = [32]
+    _AITER_UNSUPPORTED_HEADS: ClassVar[tuple[int, ...]] = ()
 
     @staticmethod
     def check_num_heads_validity(num_heads: int):
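
For context on the dtype precondition the first hunk satisfies, below is a minimal standalone sketch. It uses a simplified, ungrouped top-k (biased_grouped_topk_ref is a hypothetical reference helper, not vLLM's or AITER's API) in place of the real biased_grouped_topk kernel; the assert mirrors the kernel's requirement that the correction bias and router logits share a dtype, and the one-time .data cast mirrors the diff's init-time pre-cast.

import torch

def biased_grouped_topk_ref(
    router_logits: torch.Tensor,            # [num_tokens, num_experts]
    e_score_correction_bias: torch.Tensor,  # [num_experts]
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the AITER kernel's precondition: the bias and logits dtypes
    # must match exactly.
    assert router_logits.dtype == e_score_correction_bias.dtype, (
        f"bias dtype {e_score_correction_bias.dtype} != "
        f"logits dtype {router_logits.dtype}"
    )
    # Simplified routing: sigmoid scores plus the correction bias, then a
    # plain (ungrouped) top-k over experts.
    scores = router_logits.sigmoid() + e_score_correction_bias
    return torch.topk(scores, top_k, dim=-1)

# One-time cast at init, as in the diff: mutating .data in place means every
# module holding a reference to this nn.Parameter sees the fp32 copy.
bias = torch.nn.Parameter(torch.zeros(64, dtype=torch.bfloat16))
bias.data = bias.data.to(torch.float32)

# fp32 router logits, which is what set_out_dtype(torch.float32) requests.
logits = torch.randn(4, 64, dtype=torch.float32)
weights, ids = biased_grouped_topk_ref(logits, bias, top_k=8)

Without the pre-cast, a bf16 bias against fp32 logits would trip the dtype check on every forward pass; casting once at init also avoids repeating the conversion per step.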