Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions vllm_gaudi/ops/hpu_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,14 @@ def apply(
topk_weights = topk_weights.to(x.dtype)
topk_ids = topk_ids.view(*x.shape[:-1], -1)
topk_weights = topk_weights.view(*x.shape[:-1], -1)
if not layer.use_grouped_topk:
topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype conversions for topk_ids and topk_weights are duplicated - they appear both before line 163 (lines 159-160) and within this conditional block (lines 164-165). When use_grouped_topk is False, these conversions happen twice unnecessarily. Consider moving the earlier conversions (lines 159-160) into an else block, or removing the duplicate logic.

Suggested change
topk_weights = topk_weights.to(x.dtype)

Copilot uses AI. Check for mistakes.

output = layer.moe_op(
x,
topk_ids.to(torch.int64),
topk_weights.to(x.dtype),
topk_ids,
topk_weights,
permuted_weights=True,
activation=activation,
)
Expand Down
81 changes: 79 additions & 2 deletions vllm_gaudi/ops/hpu_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
import vllm
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod)
from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp)

Expand Down Expand Up @@ -63,11 +64,14 @@ def forward_oot(
topk_weights = topk_weights.to(x.dtype)
topk_ids = topk_ids.view(*x.shape[:-1], -1)
topk_weights = topk_weights.view(*x.shape[:-1], -1)
if not layer.use_grouped_topk:
topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
Comment on lines +67 to +69
Copy link

Copilot AI Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype conversions for topk_ids and topk_weights are now duplicated - they appear both before line 67 (lines 63-64) and within this conditional block (lines 68-69). When use_grouped_topk is False, these conversions happen twice unnecessarily. Consider moving the earlier conversions (lines 63-64) into an else block, or removing the duplicate logic.

Copilot uses AI. Check for mistakes.

return layer.moe_op(
x,
topk_ids.to(torch.int64),
topk_weights.to(x.dtype),
topk_ids,
topk_weights,
permuted_weights=True,
activation=activation,
).view(*input_shape)
Expand Down Expand Up @@ -153,7 +157,80 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
return ", ".join(mappings)


def patched_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k expert routing for MoE layers.

    Experts are partitioned into ``num_expert_group`` groups; the best
    ``topk_group`` groups are kept and the final ``topk`` experts are chosen
    only from those groups.

    Args:
        hidden_states: Input activations; only its dtype is used, for casting
            the returned routing weights.
        gating_output: Router logits of shape ``[num_tokens, num_experts]``.
        topk: Number of experts selected per token.
        renormalize: If True, routing weights are rescaled to sum to 1.
        num_expert_group: Number of expert groups.
        topk_group: Number of groups retained per token.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.
        routed_scaling_factor: Optional multiplier applied to the weights.
        e_score_correction_bias: Optional per-expert bias added to the scores
            used for *selection* only; routing weights come from the
            unbiased scores.

    Returns:
        ``(topk_weights, topk_ids)`` — weights cast to ``hidden_states.dtype``
        and expert ids as ``int64``.

    Raises:
        ValueError: If ``scoring_func`` is not supported.
    """
    logits = gating_output.float()
    bias = None
    if e_score_correction_bias is not None:
        bias = e_score_correction_bias.float()

    if scoring_func == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Deterministic (sorted) top-k selection is required for batch invariance.
    sorted_topk = vllm_is_batch_invariant()
    neg_inf = torch.finfo(scores.dtype).min

    n_tokens = scores.size(0)
    if bias is not None:
        # Selection uses biased scores, but routing weights must be taken
        # from the original (unbiased) scores — keep a reference.
        original_scores = scores
        scores = scores + bias.unsqueeze(0)
        # Group score = sum of the two best biased scores within the group.
        grouped = scores.clone().reshape(n_tokens, num_expert_group, -1)
        best_val, best_idx = torch.max(grouped, dim=-1)
        grouped.scatter_(-1, best_idx.unsqueeze(-1), neg_inf)
        group_scores, _ = torch.max(grouped, dim=-1)
        group_scores.add_(best_val)
    else:
        # Group score = single best score within the group: [n, n_group].
        group_scores = scores.view(n_tokens, num_expert_group, -1).max(dim=-1).values

    if n_tokens > 1024:
        # Large batches: pick groups one at a time via repeated argmax,
        # masking out each winner before the next round.
        group_mask = torch.zeros_like(group_scores)
        for step in range(topk_group):
            _, winner = torch.max(group_scores, dim=-1)
            group_mask.scatter_(1, winner.unsqueeze(-1), 1)
            if step < topk_group - 1:
                group_scores.scatter_(1, winner.unsqueeze(-1), neg_inf)
    else:
        winners = torch.topk(group_scores, k=topk_group, dim=-1, sorted=sorted_topk)[1]
        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, winners, 1)  # [n, n_group]

    # Push every expert in a non-selected group to -inf before the final top-k.
    masked_scores = scores.reshape(n_tokens, num_expert_group, -1) + (
        (1 - group_mask) * neg_inf).unsqueeze(-1)
    masked_scores = masked_scores.reshape(n_tokens, -1)

    if bias is not None:
        topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=sorted_topk)[1]
        # Routing weights come from the unbiased scores.
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=sorted_topk)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int64)


# Apply patches: monkey-patch the upstream vllm FusedMoE module so the
# HPU-specific implementations defined in this file are used instead.
FusedMoE.forward = patched_fused_moe_forward
# Replace the module-level expert-map formatter with the local version.
vllm.model_executor.layers.fused_moe.layer.get_compressed_expert_map = \
    get_compressed_expert_map
# Route grouped top-k expert selection through patched_grouped_topk above.
vllm.model_executor.layers.fused_moe.layer.grouped_topk = \
    patched_grouped_topk