diff --git a/vllm_gaudi/ops/hpu_fp8.py b/vllm_gaudi/ops/hpu_fp8.py
index c30cb7012..1c72f3b41 100644
--- a/vllm_gaudi/ops/hpu_fp8.py
+++ b/vllm_gaudi/ops/hpu_fp8.py
@@ -160,10 +160,16 @@ def apply(
     topk_weights = topk_weights.to(x.dtype)
     topk_ids = topk_ids.view(*x.shape[:-1], -1)
     topk_weights = topk_weights.view(*x.shape[:-1], -1)
+    # The grouped-topk path (patched_grouped_topk) already returns int64 ids
+    # and weights in x.dtype, so casting is only needed otherwise.
+    if not layer.use_grouped_topk:
+        topk_ids = topk_ids.to(torch.int64)
+        topk_weights = topk_weights.to(x.dtype)
+
     output = layer.moe_op(
         x,
-        topk_ids.to(torch.int64),
-        topk_weights.to(x.dtype),
+        topk_ids,
+        topk_weights,
         permuted_weights=True,
         activation=activation,
     )
diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py
index 2f0a107f5..158af825e 100644
--- a/vllm_gaudi/ops/hpu_fused_moe.py
+++ b/vllm_gaudi/ops/hpu_fused_moe.py
@@ -2,6 +2,7 @@
 import torch
 import vllm
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         UnquantizedFusedMoEMethod)
 from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp)
 
@@ -63,11 +64,16 @@ def forward_oot(
     topk_weights = topk_weights.to(x.dtype)
     topk_ids = topk_ids.view(*x.shape[:-1], -1)
     topk_weights = topk_weights.view(*x.shape[:-1], -1)
+    # The grouped-topk path (patched_grouped_topk) already returns int64 ids
+    # and weights in x.dtype, so casting is only needed otherwise.
+    if not layer.use_grouped_topk:
+        topk_ids = topk_ids.to(torch.int64)
+        topk_weights = topk_weights.to(x.dtype)
 
     return layer.moe_op(
         x,
-        topk_ids.to(torch.int64),
-        topk_weights.to(x.dtype),
+        topk_ids,
+        topk_weights,
         permuted_weights=True,
         activation=activation,
     ).view(*input_shape)
@@ -153,7 +159,87 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
     return ", ".join(mappings)
 
 
+def patched_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    scoring_func: str = "softmax",
+    routed_scaling_factor: float = 1.0,
+    e_score_correction_bias: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Grouped top-k expert routing replacement for vLLM's grouped_topk.
+
+    Returns (topk_weights, topk_ids) already cast to the dtypes the HPU
+    MoE op expects: weights in hidden_states.dtype, ids as int64.
+    """
+    gating_output = gating_output.float()
+    if e_score_correction_bias is not None:
+        e_score_correction_bias = e_score_correction_bias.float()
+
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_is_batch_invariant()
+
+    num_token = scores.size(0)
+    if e_score_correction_bias is not None:
+        # Store original scores before applying correction bias. We use biased
+        # scores for expert selection but original scores for routing weights
+        original_scores = scores
+        scores = scores + e_score_correction_bias.unsqueeze(0)
+        scores_tmp = scores.clone().reshape(num_token, num_expert_group, -1)
+        top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
+        scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+        group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
+        group_scores.add_(top1_val)
+    else:
+        group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+
+    if num_token > 1024:
+        group_mask = torch.zeros_like(group_scores)
+        for i in range(topk_group):
+            _, group_idx = torch.max(group_scores, dim=-1)
+            group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
+            if i < topk_group - 1:
+                group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+    else:
+        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
+        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+
+    tmp_scores = scores.reshape(num_token, num_expert_group, -1) + (
+        (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
+    tmp_scores = tmp_scores.reshape(num_token, -1)
+
+    if e_score_correction_bias is not None:
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
+        # Use original unbiased scores for the routing weights
+        topk_weights = original_scores.gather(1, topk_ids)
+    else:
+        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    if routed_scaling_factor != 1.0:
+        topk_weights = topk_weights * routed_scaling_factor
+    return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int64)
+
+
 # Apply patches
 FusedMoE.forward = patched_fused_moe_forward
 vllm.model_executor.layers.fused_moe.layer.get_compressed_expert_map = \
     get_compressed_expert_map
+# NOTE(review): rebinding the module attribute only affects callers that look
+# up grouped_topk via the module; `from ... import grouped_topk` call sites
+# keep the original function — verify the select_experts call path.
+vllm.model_executor.layers.fused_moe.layer.grouped_topk = \
+    patched_grouped_topk