20 changes: 11 additions & 9 deletions vllm/model_executor/models/deepseek_v2.py
@@ -299,6 +299,15 @@ def __init__(
         self.is_fusion_moe_shared_experts_enabled = (
             rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
         )
+        if (
+            self.is_rocm_aiter_moe_enabled
+            and self.gate.e_score_correction_bias is not None
+        ):
+            # AITER biased_grouped_topk requires the correction bias dtype to
+            # match the router logits. Keep DeepSeek's correction bias in fp32
+            # by requesting fp32 router logits for this routing path.
+            self.gate.set_out_dtype(torch.float32)
+
         if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
             self.shared_experts = None
         else:
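
For readers outside the diff: a minimal standalone sketch of the dtype mismatch the new set_out_dtype(torch.float32) call avoids. The bare-matmul gate and tensor shapes below are illustrative assumptions, not vLLM's actual ReplicatedLinear path; the point is only that AITER's biased_grouped_topk needs the fp32 correction bias and the router logits to share a dtype.

    import torch

    # Illustrative stand-in for the gate; shapes are arbitrary assumptions.
    hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)         # token activations
    gate_weight = torch.randn(16, 8, dtype=torch.bfloat16)          # 16 routed experts
    e_score_correction_bias = torch.zeros(16, dtype=torch.float32)  # fp32 in DeepSeek checkpoints

    # Default path: logits inherit the activation dtype (bf16), which would
    # mismatch the fp32 bias inside biased_grouped_topk.
    router_logits = hidden_states @ gate_weight.t()
    assert router_logits.dtype != e_score_correction_bias.dtype

    # Forcing fp32 gate output (what set_out_dtype requests) makes the two
    # dtypes agree without casting the bias on every forward pass.
    router_logits_fp32 = hidden_states.float() @ gate_weight.float().t()
    assert router_logits_fp32.dtype == e_score_correction_bias.dtype
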
@@ -338,22 +347,15 @@ def __init__(
                 n_shared_experts=config.n_shared_experts
                 if self.is_fusion_moe_shared_experts_enabled
                 else None,
+                router_logits_dtype=self.gate.out_dtype,
             )

-        # Pre-cast the bias to match the gate output dtype so the
-        # conversion is not repeated on every forward pass. All
-        # downstream references (FusedMoE, router) share the same
-        # nn.Parameter object, so mutating .data propagates everywhere.
-        # Weight loading uses copy_(), which handles the dtype conversion.
-        # Only needed on ROCm where the aiter biased_grouped_topk kernel
-        # requires the bias dtype to match the gating output dtype.

Contributor: Can we leave this comment here?

         if (
             self.is_rocm_aiter_moe_enabled
             and self.gate.e_score_correction_bias is not None
         ):
-            gate_out_dtype = self.gate.out_dtype or self.gate.weight.dtype
             self.gate.e_score_correction_bias.data = (
-                self.gate.e_score_correction_bias.data.to(gate_out_dtype)
+                self.gate.e_score_correction_bias.data.to(self.gate.out_dtype)
             )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
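
The removed comment leans on a PyTorch detail worth making concrete: every module that captured the bias holds the same nn.Parameter object, so a one-time mutation of .data is visible everywhere. A minimal sketch, with dictionary holders standing in for the gate and FusedMoE (names are illustrative):

    import torch
    from torch import nn

    # One shared Parameter, referenced from two places.
    bias = nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
    gate_ref = {"e_score_correction_bias": bias}  # stands in for the gate
    moe_ref = {"e_score_correction_bias": bias}   # stands in for FusedMoE

    # One-time pre-cast, as in the diff above: rebind .data with a new dtype.
    bias.data = bias.data.to(torch.float32)

    # Both holders observe the cast because they alias the same Parameter.
    assert gate_ref["e_score_correction_bias"].dtype == torch.float32
    assert moe_ref["e_score_correction_bias"].dtype == torch.float32
    assert gate_ref["e_score_correction_bias"] is moe_ref["e_score_correction_bias"]
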
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -396,7 +396,7 @@ class AiterMLAHelper:
"""

_AITER_MIN_MLA_HEADS: Final = 16
_AITER_UNSUPPORTED_HEADS = [32]
Contributor: If we're certain head size 32 works, we can remove this check/var entirely. cc @ganyi1996ppo since this check was added in #41217

+    _AITER_UNSUPPORTED_HEADS: ClassVar[tuple[int, ...]] = ()

     @staticmethod
     def check_num_heads_validity(num_heads: int):
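
For context on what the emptied tuple feeds into, here is a hypothetical reconstruction of the check around the two visible constants; only the signature and constants appear in the diff, so the real method body in rocm_aiter_mla.py may differ:

    from typing import ClassVar, Final


    class AiterMLAHelper:
        _AITER_MIN_MLA_HEADS: Final = 16
        _AITER_UNSUPPORTED_HEADS: ClassVar[tuple[int, ...]] = ()

        @staticmethod
        def check_num_heads_validity(num_heads: int):
            # Hypothetical body: reject head counts AITER MLA cannot serve.
            if num_heads < AiterMLAHelper._AITER_MIN_MLA_HEADS:
                raise ValueError(
                    f"AITER MLA requires at least "
                    f"{AiterMLAHelper._AITER_MIN_MLA_HEADS} heads, got {num_heads}"
                )
            if num_heads in AiterMLAHelper._AITER_UNSUPPORTED_HEADS:
                raise ValueError(f"AITER MLA does not support {num_heads} heads")

With the tuple now empty, the second branch can never fire, which lines up with the reviewer's point that the check could be dropped entirely once head size 32 is confirmed to work.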