From 9d3679bfa9e9fa9d74236ae8c9448074b481ef65 Mon Sep 17 00:00:00 2001 From: Michael <13900043+michaelzhang-ai@users.noreply.github.com> Date: Tue, 26 May 2026 08:39:01 -0500 Subject: [PATCH] [perf][spec decoding] Re-land #26235 (skip EAGLE topk==1 softmax) for CUDA only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Re-lands the perf optimization from #26235 — which skips the full-vocab softmax + fast_topk when `self.topk == 1` and uses `argmax(logits)` with a placeholder `topk_p = ones` — but gates it OFF on ROCm/HIP, where it was the cause of R108 (DSv3.2 + MTP gsm8k accuracy collapse to 0.035 with ~96% invalid output). ## Why gate, not redo Per verification on the revert PR (#26358), AMD MTP draft paths consume `topk_p` somewhere downstream in a way that depends on it being the actual softmax probability, not a placeholder. The exact downstream read site has not been identified; gating off on HIP is the zero-correctness-risk way to preserve the CUDA perf win while keeping AMD safe. Evidence the gate is sufficient: - Reverting all 3 sites recovered DSv3.2-MTP gsm8k on ROCm 7.2 from 0.035 → 0.975 ([revert verify](https://github.com/sgl-project/sglang/actions/runs/26438872740/job/77828088922)). - Pre-#26235 (with full-vocab softmax) was the historical green state on AMD for weeks; restoring that branch on HIP returns to known-good. ## What the gate looks like 3 sites get the same shape (verbatim across files): if self.topk == 1 and not _is_hip: # topk=1 → degenerate single-path tree; skip full-vocab softmax # and use argmax(logits) directly. Gated off on ROCm/HIP because # the MTP draft path is sensitive to whether topk_p is the true # probability or a placeholder; see #26358 (revert) / R108. ret.topk_index = torch.argmax( ret.next_token_logits, dim=-1, keepdim=True ) ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32) else: probs = torch.softmax(ret.next_token_logits, dim=-1) ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) `_is_hip` is already module-scope in eagle_worker_v2.py; added a parallel module-scope binding in eagle_draft_extend_cuda_graph_runner.py. ## Files - python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+ is_hip import, + _is_hip module binding, + 1 site gate) - python/sglang/srt/speculative/eagle_worker_v2.py (+ 2 site gates) ## Tested - Lint clean (no new ruff/flake8 issues). - AMD CI on the parent commit (origin/main = `a26913158`) is the baseline to compare against — this PR is a no-op on HIP at runtime, so AMD-side CI behavior should match origin/main exactly. ## References - Original PR (reverted): #26235 by @Qiaolin-Yu - Revert PR (merged 2026-05-26): #26358 - CI tracker: https://github.com/bingxche/sglang-ci-bot/issues/84 (R108) cc @Qiaolin-Yu (original author) — this is a re-land of your perf optimization with the AMD safety gate we discussed in #26358. --- .../eagle_draft_extend_cuda_graph_runner.py | 17 +++++++++-- .../sglang/srt/speculative/eagle_worker_v2.py | 28 ++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index ad17631bcc88..bf4c865e09eb 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -30,12 +30,15 @@ from sglang.srt.speculative.eagle_info import EagleDraftExtendInput from sglang.srt.speculative.spec_utils import fast_topk from sglang.srt.utils import ( + is_hip, require_attn_tp_gather, require_gathered_buffer, require_mlp_sync, require_mlp_tp_gather, ) +_is_hip = is_hip() + if TYPE_CHECKING: from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -401,8 +404,18 @@ def run_once(): forward_batch.positions, forward_batch, ) - probs = torch.softmax(ret.next_token_logits, dim=-1) - ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1 and not _is_hip: + # topk=1 → degenerate single-path tree; skip full-vocab softmax + # and use argmax(logits) directly. Gated off on ROCm/HIP because + # the MTP draft path is sensitive to whether topk_p is the true + # probability or a placeholder; see #26358 (revert) / R108. + ret.topk_index = torch.argmax( + ret.next_token_logits, dim=-1, keepdim=True + ) + ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32) + else: + probs = torch.softmax(ret.next_token_logits, dim=-1) + ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) forward_batch.out_cache_loc = output_cache_loc_backup forward_batch.spec_info.hidden_states = hidden_states_backup diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index eb5c40a7791b..8c6d3bd8a53d 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -483,8 +483,18 @@ def draft_forward(self, forward_batch: ForwardBatch): forward_batch, skip_attn_backend_init=True ).logits_output maybe_detect_nan(logits_output.next_token_logits, f"draft_forward step {i}") - probs = torch.softmax(logits_output.next_token_logits, dim=-1) - topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1 and not _is_hip: + # topk=1 → degenerate single-path tree; skip full-vocab softmax + # and use argmax(logits) directly. Gated off on ROCm/HIP because + # the MTP draft path is sensitive to whether topk_p is the true + # probability or a placeholder; see #26358 (revert) / R108. + topk_index = torch.argmax( + logits_output.next_token_logits, dim=-1, keepdim=True + ) + topk_p = torch.ones_like(topk_index, dtype=torch.float32) + else: + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) maybe_detect_oob( topk_index, 0, @@ -651,8 +661,18 @@ def _draft_extend_for_decode( draft_logits_output.hidden_states = draft_logits_output.hidden_states[ select_index ] - probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) - ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1 and not _is_hip: + # topk=1 → degenerate single-path tree; skip full-vocab softmax + # and use argmax(logits) directly. Gated off on ROCm/HIP because + # the MTP draft path is sensitive to whether topk_p is the true + # probability or a placeholder; see #26358 (revert) / R108. + ret_topk_index = torch.argmax( + draft_logits_output.next_token_logits, dim=-1, keepdim=True + ) + ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32) + else: + probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) + ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_hidden_states = draft_logits_output.hidden_states # Construct the return values