Skip to content

Revert "[perf][spec decoding] Skip full-vocab softmax in EAGLE draft when topk == 1 (#26235)"#26358

Merged
Qiaolin-Yu merged 2 commits into
sgl-project:mainfrom
michaelzhang-ai:revert-26235-eagle-topk1-softmax-skip
May 26, 2026
Merged

Revert "[perf][spec decoding] Skip full-vocab softmax in EAGLE draft when topk == 1 (#26235)"#26358
Qiaolin-Yu merged 2 commits into
sgl-project:mainfrom
michaelzhang-ai:revert-26235-eagle-topk1-softmax-skip

Conversation

@michaelzhang-ai
Copy link
Copy Markdown
Collaborator

@michaelzhang-ai michaelzhang-ai commented May 26, 2026

Motivation

Reverts #26235 (commit a77449f, merged 2026-05-25 09:06 UTC) because it caused DeepSeek-V3.2 + MTP GSM8K accuracy to collapse to ~0.04 with ~96% invalid output on the very next nightly run.

Verification on the revert branch — CONFIRMED FIX on ROCm 7.2

Re-ran nightly-accuracy-8-gpu-mi35x-deepseek-v32-mtp-rocm720 against this PR's revert branch (run 26438872740):

Accuracy: 0.975  (threshold 0.94, was 0.035 before revert)
elapsed=2533s, status=success

This is a clean 3.5 pp above threshold — a full recovery of MTP GSM8K accuracy on ROCm 7.2 mi35x. Compare to the failing pre-revert score: 0.035 / invalid=0.96 on the same hardware and same Dockerfile-default aiter (32e1e6d7).

Evidence (verified verbatim from logs)

Failure signature on the rocm720 nightly:

metrics={'accuracy': 0.035, 'invalid': 0.96,
         'latency': 43.7, 'output_throughput': 2283}
AssertionError: 0.035 not greater than 0.94

Cross-run comparison:

Run sglang aiter DSv3.2-MTP gsm8k rocm720
Last green 2026-05-24 7f45bcdd2a (before #26235) 32e1e6d7 (default) PASS
Nightly 2026-05-25 (rocm720) b13d3d18c (contains #26235) 32e1e6d7 (default) 0.035 / invalid=0.96 FAIL
Manual scout 2026-05-25 (this commit + known-good aiter, rocm700) a77449f86 (= #26235) d7caa3d2 (good baseline) 0.035 / invalid=0.96 FAIL
Revert PR verify 2026-05-26 (rocm720) 9a7258680 (this PR HEAD) 32e1e6d7 (default) 0.975 PASS

The combination of (a) the manual scout reproducing R108 on a77449f86 with known-good aiter and (b) the revert-branch verify scoring 0.975 on rocm720 with default aiter proves the cause is on the sglang side, specifically #26235.

References

Modifications

This is a clean git revert of a77449f86.

Files reverted:

  • python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+10/-2 → 0)
  • python/sglang/srt/speculative/eagle_worker_v2.py (+22/-4 → 0)

Path forward

The perf optimization in #26235 is valid for the non-MTP EAGLE path. Suggest re-landing with an additional guard such as topk == 1 AND not is_mtp_path (or equivalent) so it skips the MTP path, which appears to depend on the full-vocab softmax for downstream draft selection.

cc @Qiaolin-Yu (author of #26235) — happy to coordinate on the re-land patch once a guard is identified.

Checklist


CI States

Latest PR Test (Base): ⏳ Run #26443970447
Latest PR Test (Extra): ❌ Run #26443970366

…when topk == 1 (sgl-project#26235)"

This reverts commit a77449f.

## Why

After sgl-project#26235 landed (2026-05-25 09:06 UTC), nightly DeepSeek-V3.2 + MTP
GSM8K accuracy collapsed to ~0.04 with ~96% invalid output, breaking
both the rocm720 nightly and the manual scout run on this same SHA with
a known-good aiter pin (so the regression is sglang-side, not aiter).

Signature observed (verbatim from logs):

  metrics={'accuracy': 0.035, 'invalid': 0.96,
           'latency': 43.7, 'output_throughput': 2283}
  AssertionError: 0.035 not greater than 0.94

## Evidence

| Run | sglang | aiter | DSv3.2-MTP gsm8k |
|---|---|---|---:|
| Last green 2026-05-24 | 7f45bcd (before sgl-project#26235) | 32e1e6d7 (default) | PASS |
| nightly 2026-05-25 (rocm720) | b13d3d1 (contains sgl-project#26235) | 32e1e6d7 (default) | 0.035 / invalid=0.96 FAIL |
| manual scout 2026-05-25 (this commit + good aiter) | a77449f (= sgl-project#26235) | d7caa3d2 (good baseline) | 0.035 / invalid=0.96 FAIL |

The manual scout proves the failure reproduces on a77449f even with
the previously-known-good aiter override, so the cause is sglang-side
within the 7f45bcd..a77449f window. sgl-project#26235 is the only EAGLE-draft-
touching commit in that window, and the failure mode (correct accept-
length, invalid text) matches a draft-token-selection bug.

References:
- nightly job: https://github.com/sgl-project/sglang/actions/runs/26413553204/job/77753010360
- manual scout: https://github.com/sgl-project/sglang/actions/runs/26392698141/job/77803129088
- last green: https://github.com/sgl-project/sglang/actions/runs/26368356160 (2026-05-24)

## Path forward

The perf optimization in sgl-project#26235 is valid for non-MTP EAGLE paths.
Re-land with an additional guard (e.g. \`topk == 1 AND not is_mtp_path\`)
so it skips the MTP path, which appears to depend on the full-vocab
softmax that was dropped for downstream draft selection.

cc @Qiaolin-Yu (author) — happy to coordinate on the re-land patch.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request simplifies the draft extend and forward logic by removing the special-case handling for self.topk == 1 and instead applying a full-vocab softmax and fast_topk across all cases. The reviewer suggests optimizing the self.topk == 1 case using a mathematical identity to compute the top-1 probability directly, which avoids the memory bandwidth overhead of a full-vocab softmax and fast_topk without sacrificing accuracy.

Comment on lines +404 to +405
probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of reverting completely to the full-vocab softmax when self.topk == 1, we can mathematically optimize the top-1 probability computation. By using the identity $\text{softmax}(x)_{\text{argmax}} = 1 / \sum e^{x_i - \max(x)}$, we can compute topk_p and topk_index without materializing the full softmax tensor in HBM. This avoids the accuracy collapse caused by hardcoding topk_p = 1.0 while retaining the performance benefits of skipping the full-vocab softmax.

Suggested change
probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
if self.topk == 1:
max_logits, ret.topk_index = torch.max(ret.next_token_logits, dim=-1, keepdim=True)
ret.topk_p = 1.0 / torch.exp(ret.next_token_logits - max_logits).sum(dim=-1, keepdim=True)
else:
probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)

Comment on lines +486 to +487
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can optimize the top-1 probability computation here as well when self.topk == 1 by using the mathematical identity $\text{softmax}(x)_{\text{argmax}} = 1 / \sum e^{x_i - \max(x)}$. This avoids the memory bandwidth overhead of full-vocab softmax and fast_topk without causing any accuracy degradation.

            if self.topk == 1:
                max_logits, topk_index = torch.max(logits_output.next_token_logits, dim=-1, keepdim=True)
                topk_p = 1.0 / torch.exp(logits_output.next_token_logits - max_logits).sum(dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits_output.next_token_logits, dim=-1)
                topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)

Comment on lines +654 to +655
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Apply the same top-1 softmax optimization here to avoid full-vocab softmax and fast_topk overhead when self.topk == 1.

Suggested change
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
if self.topk == 1:
max_logits, ret_topk_index = torch.max(draft_logits_output.next_token_logits, dim=-1, keepdim=True)
ret_topk_p = 1.0 / torch.exp(draft_logits_output.next_token_logits - max_logits).sum(dim=-1, keepdim=True)
else:
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)

@Qiaolin-Yu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants