Fix Topk Calculation in GPTOSS #970
Conversation
Pull request overview
Adjusts MoE router top-k selection to handle GPTOSS (gpt_oss) routing correctly by selecting experts from logits before applying softmax.
Changes:
- Adds a gpt_oss-specific routing path that applies `topk` on logits and then softmaxes the selected logits.
- Keeps existing behavior for other model types (softmax over all experts, then `topk`, then renormalize).
```python
if self.model_type is not None and self.model_type in ["gpt_oss"]:
    topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
    topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
else:
    topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
    topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
```
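As a standalone illustration of the two routing paths above (the function names here are hypothetical, not from the PR): in exact arithmetic, selecting top-k on the logits and softmaxing the selection yields the same weights as softmaxing all experts, taking top-k, and renormalizing, so the branches differ mainly in tie-breaking, dtype, and the dimension the ops run on.

```python
import torch
import torch.nn.functional as F

def route_gpt_oss(router_logits: torch.Tensor, top_k: int):
    # gpt_oss path: pick experts on the raw logits, then softmax
    # only the selected logits
    topk_weights, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
    return topk_weights, topk_ids

def route_default(router_logits: torch.Tensor, top_k: int):
    # existing path: softmax over all experts, then top-k,
    # then renormalize the selected weights to sum to 1
    weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```

With float32 logits and no ties the two functions agree to numerical precision; the practical differences surface with lower-precision logits and with the `dim` handling discussed in the review comment.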
The new gpt_oss branch uses softmax(..., dim=-1) while the existing path uses softmax(..., dim=1). This introduces inconsistent normalization across model types and can produce incorrect routing if router_logits is not strictly 2D (or if the expert dimension isn't 1). Consider using the same expert dimension for both branches (typically dim=-1), or deriving the expert dimension once and reusing it for softmax/topk.
Suggested change:

```diff
-if self.model_type is not None and self.model_type in ["gpt_oss"]:
-    topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
-    topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
-else:
-    topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
-    topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
-    topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+expert_dim = router_logits.dim() - 1
+if self.model_type is not None and self.model_type in ["gpt_oss"]:
+    topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=expert_dim)
+    topk_weights = F.softmax(topk_weights, dim=expert_dim, dtype=torch.float32)
+else:
+    topk_weights = F.softmax(router_logits, dim=expert_dim, dtype=torch.float32)
+    topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=expert_dim)
+    topk_weights /= topk_weights.sum(dim=expert_dim, keepdim=True)
```
I know. Just changing to dim=-1 in softmax would resolve the issue. But for some models it might be required to apply softmax on the second dimension, so I did not change it.
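The reviewer's concern can be checked directly: for 2D `router_logits`, `dim=1` and `dim=-1` address the same (expert) axis, so the two spellings agree, while for any higher-rank tensor they diverge. A minimal sketch:

```python
import torch
import torch.nn.functional as F

# 2D logits: dim=1 and dim=-1 refer to the same axis, so the
# results are identical
x2d = torch.randn(4, 8)
assert torch.equal(F.softmax(x2d, dim=1), F.softmax(x2d, dim=-1))

# 3D logits (e.g. an extra batch or sequence axis): dim=1 now
# normalizes over the wrong axis, silently producing incorrect
# routing weights
x3d = torch.randn(2, 4, 8)
assert not torch.equal(F.softmax(x3d, dim=1), F.softmax(x3d, dim=-1))
```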
✅ CI Passed — All checks passed successfully against the following vllm commit:
@SKRohit do we need to merge this PR?

@iboiko-habana I am running a few tests; I will let you know once the changes are ready for merge.
🚧 CI Blocked — The main CI workflow was not started for the following reason:
@iboiko-habana I have verified the changes. The PR is ready to merge from my side.
✅ CI Passed — All checks passed successfully against the following vllm commit:
Force-pushed a9c538f to 65ccd95
🚧 CI Blocked — The main CI workflow was not started for the following reason:
Force-pushed 65ccd95 to 2d2838c
Signed-off-by: Rohit kumar Singh <rksingh@habana.ai>
Signed-off-by: Rohit Kumar Singh <9626333+SKRohit@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Force-pushed 2d2838c to 40c3d7f
✅ CI Passed — All checks passed successfully against the following vllm commit:
@iboiko-habana Can we merge?
Fixes Accuracy Issue in GPTOSS: vllm-project#887. Updates `apply_monolithic` introduced in vllm-project#876 to handle gptoss.

Signed-off-by: Rohit kumar Singh <rksingh@habana.ai>
Signed-off-by: Rohit Kumar Singh <9626333+SKRohit@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Hi @iboiko-habana @SKRohit, we encountered this issue (#891 (comment)) earlier. However, with the new branch, we are again seeing a drop in accuracy when using the Unsloth version of GPTOSS.