-
Notifications
You must be signed in to change notification settings - Fork 128
Patch Grouped Topk #708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Patch Grouped Topk #708
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import torch | ||
| import vllm | ||
| from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant | ||
| from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) | ||
| from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp) | ||
|
|
||
|
|
@@ -63,11 +64,14 @@ def forward_oot( | |
| topk_weights = topk_weights.to(x.dtype) | ||
| topk_ids = topk_ids.view(*x.shape[:-1], -1) | ||
| topk_weights = topk_weights.view(*x.shape[:-1], -1) | ||
| if not layer.use_grouped_topk: | ||
| topk_ids = topk_ids.to(torch.int64) | ||
| topk_weights = topk_weights.to(x.dtype) | ||
|
Comment on lines +67 to +69:
||
|
|
||
| return layer.moe_op( | ||
| x, | ||
| topk_ids.to(torch.int64), | ||
| topk_weights.to(x.dtype), | ||
| topk_ids, | ||
| topk_weights, | ||
| permuted_weights=True, | ||
| activation=activation, | ||
| ).view(*input_shape) | ||
|
|
@@ -153,7 +157,80 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: | |
| return ", ".join(mappings) | ||
|
|
||
|
|
||
def patched_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select top-k experts per token using grouped (node-limited) routing.

    Experts are partitioned into ``num_expert_group`` groups; only the
    ``topk_group`` best-scoring groups are kept, and the final ``topk``
    experts are chosen from within those groups.

    Args:
        hidden_states: Input activations; only its dtype is used, to cast
            the returned routing weights.
        gating_output: Router logits of shape ``[num_tokens, num_experts]``.
        topk: Number of experts to select per token.
        renormalize: If True, rescale the selected weights to sum to 1.
        num_expert_group: Number of expert groups the experts are split into.
        topk_group: Number of groups to keep per token.
        scoring_func: ``"softmax"`` or ``"sigmoid"`` activation on the logits.
        routed_scaling_factor: Optional scalar applied to the final weights.
        e_score_correction_bias: Optional per-expert bias added to the scores
            for *selection only*; the unbiased scores are used as weights.

    Returns:
        Tuple ``(topk_weights, topk_ids)`` with shapes
        ``[num_tokens, topk]``; weights in ``hidden_states.dtype``,
        ids as ``torch.int64``.

    Raises:
        ValueError: If ``scoring_func`` is not supported.
    """
    # Compute routing scores in float32 for numerical stability.
    gating_output = gating_output.float()
    if e_score_correction_bias is not None:
        e_score_correction_bias = e_score_correction_bias.float()

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # For batch invariance, use sorted=True to ensure deterministic
    # expert selection regardless of batch composition.
    use_sorted = vllm_is_batch_invariant()

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights.
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        # Group score = sum of the two best expert scores in the group:
        # take the max, mask it out, take the max again, then add them.
        scores_tmp = scores.clone().reshape(num_token, num_expert_group, -1)
        top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
        scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1),
                            torch.finfo(scores.dtype).min)
        group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
        group_scores.add_(top1_val)
    else:
        # Group score = best single expert score in the group.  [n, n_group]
        group_scores = (scores.view(num_token, num_expert_group,
                                    -1).max(dim=-1).values)

    if num_token > 1024:
        # For large batches, select the top groups iteratively with repeated
        # argmax instead of a single topk call (presumably a Gaudi-specific
        # performance/determinism trade-off — confirm with profiling).
        group_mask = torch.zeros_like(group_scores)
        for i in range(topk_group):
            _, group_idx = torch.max(group_scores, dim=-1)
            group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
            if i < topk_group - 1:
                # Mask the chosen group so the next iteration picks another.
                group_scores.scatter_(1, group_idx.unsqueeze(-1),
                                      torch.finfo(scores.dtype).min)
    else:
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                               sorted=use_sorted)[1]  # [n, top_k_group]
        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

    # Push scores of experts in unselected groups to -inf so they can
    # never win the final top-k.
    tmp_scores = scores.reshape(num_token, num_expert_group, -1) + (
        (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
    tmp_scores = tmp_scores.reshape(num_token, -1)

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights.
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1,
                                            sorted=use_sorted)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int64)
|
|
||
|
|
||
# Apply patches: monkey-patch vLLM's fused-MoE layer so it uses the
# Gaudi-aware implementations defined in this module.
FusedMoE.forward = patched_fused_moe_forward
vllm.model_executor.layers.fused_moe.layer.get_compressed_expert_map = \
    get_compressed_expert_map
# Module-level patch: callers that imported `grouped_topk` by module
# attribute get the batch-invariant grouped routing above.
vllm.model_executor.layers.fused_moe.layer.grouped_topk = \
    patched_grouped_topk
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The dtype conversions for `topk_ids` and `topk_weights` are duplicated — they appear both before line 163 (lines 159–160) and within this conditional block (lines 164–165). When `use_grouped_topk` is False, these conversions happen twice unnecessarily. Consider moving the earlier conversions (lines 159–160) into an else block, or removing the duplicate logic.