Revert "[perf][spec decoding] Skip full-vocab softmax in EAGLE draft when topk == 1 (#26235)"#26358
Conversation
…when topk == 1 (sgl-project#26235)" This reverts commit a77449f. ## Why After sgl-project#26235 landed (2026-05-25 09:06 UTC), nightly DeepSeek-V3.2 + MTP GSM8K accuracy collapsed to ~0.04 with ~96% invalid output, breaking both the rocm720 nightly and the manual scout run on this same SHA with a known-good aiter pin (so the regression is sglang-side, not aiter). Signature observed (verbatim from logs): metrics={'accuracy': 0.035, 'invalid': 0.96, 'latency': 43.7, 'output_throughput': 2283} AssertionError: 0.035 not greater than 0.94 ## Evidence | Run | sglang | aiter | DSv3.2-MTP gsm8k | |---|---|---|---:| | Last green 2026-05-24 | 7f45bcd (before sgl-project#26235) | 32e1e6d7 (default) | PASS | | nightly 2026-05-25 (rocm720) | b13d3d1 (contains sgl-project#26235) | 32e1e6d7 (default) | 0.035 / invalid=0.96 FAIL | | manual scout 2026-05-25 (this commit + good aiter) | a77449f (= sgl-project#26235) | d7caa3d2 (good baseline) | 0.035 / invalid=0.96 FAIL | The manual scout proves the failure reproduces on a77449f even with the previously-known-good aiter override, so the cause is sglang-side within the 7f45bcd..a77449f window. sgl-project#26235 is the only EAGLE-draft- touching commit in that window, and the failure mode (correct accept- length, invalid text) matches a draft-token-selection bug. References: - nightly job: https://github.com/sgl-project/sglang/actions/runs/26413553204/job/77753010360 - manual scout: https://github.com/sgl-project/sglang/actions/runs/26392698141/job/77803129088 - last green: https://github.com/sgl-project/sglang/actions/runs/26368356160 (2026-05-24) ## Path forward The perf optimization in sgl-project#26235 is valid for non-MTP EAGLE paths. Re-land with an additional guard (e.g. \`topk == 1 AND not is_mtp_path\`) so it skips the MTP path, which appears to depend on the full-vocab softmax that was dropped for downstream draft selection. cc @Qiaolin-Yu (author) — happy to coordinate on the re-land patch.
There was a problem hiding this comment.
Code Review
This pull request simplifies the draft extend and forward logic by removing the special-case handling for self.topk == 1 and instead applying a full-vocab softmax and fast_topk across all cases. The reviewer suggests optimizing the self.topk == 1 case using a mathematical identity to compute the top-1 probability directly, which avoids the memory bandwidth overhead of a full-vocab softmax and fast_topk without sacrificing accuracy.
| probs = torch.softmax(ret.next_token_logits, dim=-1) | ||
| ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) |
There was a problem hiding this comment.
Instead of reverting completely to the full-vocab softmax when self.topk == 1, we can mathematically optimize the top-1 probability computation. By using the identity topk_p and topk_index without materializing the full softmax tensor in HBM. This avoids the accuracy collapse caused by hardcoding topk_p = 1.0 while retaining the performance benefits of skipping the full-vocab softmax.
| probs = torch.softmax(ret.next_token_logits, dim=-1) | |
| ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) | |
| if self.topk == 1: | |
| max_logits, ret.topk_index = torch.max(ret.next_token_logits, dim=-1, keepdim=True) | |
| ret.topk_p = 1.0 / torch.exp(ret.next_token_logits - max_logits).sum(dim=-1, keepdim=True) | |
| else: | |
| probs = torch.softmax(ret.next_token_logits, dim=-1) | |
| ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) |
| probs = torch.softmax(logits_output.next_token_logits, dim=-1) | ||
| topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) |
There was a problem hiding this comment.
We can optimize the top-1 probability computation here as well when self.topk == 1 by using the mathematical identity fast_topk without causing any accuracy degradation.
if self.topk == 1:
max_logits, topk_index = torch.max(logits_output.next_token_logits, dim=-1, keepdim=True)
topk_p = 1.0 / torch.exp(logits_output.next_token_logits - max_logits).sum(dim=-1, keepdim=True)
else:
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)| probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) | ||
| ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) |
There was a problem hiding this comment.
Apply the same top-1 softmax optimization here to avoid full-vocab softmax and fast_topk overhead when self.topk == 1.
| probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) | |
| ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) | |
| if self.topk == 1: | |
| max_logits, ret_topk_index = torch.max(draft_logits_output.next_token_logits, dim=-1, keepdim=True) | |
| ret_topk_p = 1.0 / torch.exp(draft_logits_output.next_token_logits - max_logits).sum(dim=-1, keepdim=True) | |
| else: | |
| probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) | |
| ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) |
|
/tag-and-rerun-ci |
Motivation
Reverts #26235 (commit a77449f, merged 2026-05-25 09:06 UTC) because it caused DeepSeek-V3.2 + MTP GSM8K accuracy to collapse to ~0.04 with ~96% invalid output on the very next nightly run.
Verification on the revert branch — CONFIRMED FIX on ROCm 7.2
Re-ran
nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp-rocm720against this PR's revert branch (run 26438872740):This is a clean 3.5 pp above threshold — a full recovery of MTP GSM8K accuracy on ROCm 7.2 mi35x. Compare to the failing pre-revert score: 0.035 / invalid=0.96 on the same hardware and same Dockerfile-default aiter (
32e1e6d7).Evidence (verified verbatim from logs)
Failure signature on the rocm720 nightly:
Cross-run comparison:
7f45bcdd2a(before #26235)32e1e6d7(default)b13d3d18c(contains #26235)32e1e6d7(default)a77449f86(= #26235)d7caa3d2(good baseline)9a7258680(this PR HEAD)32e1e6d7(default)The combination of (a) the manual scout reproducing R108 on
a77449f86with known-good aiter and (b) the revert-branch verify scoring 0.975 on rocm720 with default aiter proves the cause is on the sglang side, specifically #26235.References
Modifications
This is a clean
git revertofa77449f86.Files reverted:
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py(+10/-2 → 0)python/sglang/srt/speculative/eagle_worker_v2.py(+22/-4 → 0)Path forward
The perf optimization in #26235 is valid for the non-MTP EAGLE path. Suggest re-landing with an additional guard such as
topk == 1 AND not is_mtp_path(or equivalent) so it skips the MTP path, which appears to depend on the full-vocab softmax for downstream draft selection.cc @Qiaolin-Yu (author of #26235) — happy to coordinate on the re-land patch once a guard is identified.
Checklist
CI States
Latest PR Test (Base): ⏳ Run #26443970447
Latest PR Test (Extra): ❌ Run #26443970366