From 62ff4af9904230049a5066e937e5e2074fe2f1c6 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Tue, 26 May 2026 03:10:55 -0700 Subject: [PATCH 1/2] =?UTF-8?q?Revert=20"Revert=20"[perf][spec=20decoding]?= =?UTF-8?q?=20Skip=20full-vocab=20softmax=20in=20EAGLE=20draft=20=E2=80=A6?= =?UTF-8?q?"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 9409969fd5a0089017ad20d2ce0760afe0a02ea8. --- .../eagle_draft_extend_cuda_graph_runner.py | 10 +++++++-- .../sglang/srt/speculative/eagle_worker_v2.py | 22 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index ad17631bcc88..8798086d147b 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -401,8 +401,14 @@ def run_once(): forward_batch.positions, forward_batch, ) - probs = torch.softmax(ret.next_token_logits, dim=-1) - ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + ret.topk_index = torch.argmax( + ret.next_token_logits, dim=-1, keepdim=True + ) + ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32) + else: + probs = torch.softmax(ret.next_token_logits, dim=-1) + ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) forward_batch.out_cache_loc = output_cache_loc_backup forward_batch.spec_info.hidden_states = hidden_states_backup diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index eb5c40a7791b..c864dc12dc76 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -483,8 +483,16 @@ def draft_forward(self, forward_batch: ForwardBatch): forward_batch, skip_attn_backend_init=True ).logits_output 
maybe_detect_nan(logits_output.next_token_logits, f"draft_forward step {i}") - probs = torch.softmax(logits_output.next_token_logits, dim=-1) - topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + # topk=1 → degenerate single-path tree; `topk_p` is unused + # downstream, so skip softmax and just argmax over logits. + topk_index = torch.argmax( + logits_output.next_token_logits, dim=-1, keepdim=True + ) + topk_p = torch.ones_like(topk_index, dtype=torch.float32) + else: + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) maybe_detect_oob( topk_index, 0, @@ -651,8 +659,14 @@ def _draft_extend_for_decode( draft_logits_output.hidden_states = draft_logits_output.hidden_states[ select_index ] - probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) - ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + ret_topk_index = torch.argmax( + draft_logits_output.next_token_logits, dim=-1, keepdim=True + ) + ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32) + else: + probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) + ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_hidden_states = draft_logits_output.hidden_states # Construct the return values From 86702aa2d68819265dcef261f9fb947a1e66c7b5 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Tue, 26 May 2026 20:08:54 +0000 Subject: [PATCH 2/2] Gate topk=1 argmax fastpath to CUDA; keep softmax+topk path on ROCm --- .../speculative/eagle_draft_extend_cuda_graph_runner.py | 8 +++++++- python/sglang/srt/speculative/eagle_worker_v2.py | 9 +++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 8798086d147b..d2992024eb08 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++
b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -30,12 +30,15 @@ from sglang.srt.speculative.eagle_info import EagleDraftExtendInput from sglang.srt.speculative.spec_utils import fast_topk from sglang.srt.utils import ( + is_hip, require_attn_tp_gather, require_gathered_buffer, require_mlp_sync, require_mlp_tp_gather, ) +_is_hip = is_hip() + if TYPE_CHECKING: from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -401,7 +404,10 @@ def run_once(): forward_batch.positions, forward_batch, ) - if self.topk == 1: + # ROCm's argmax tie-breaks differently from CUDA's softmax+max + # path on FP8 logits, which corrupts MTP draft selection on AMD. + # Keep the fastpath CUDA-only. + if self.topk == 1 and not _is_hip: ret.topk_index = torch.argmax( ret.next_token_logits, dim=-1, keepdim=True ) diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index c864dc12dc76..46f98bcb6cbb 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -483,9 +483,12 @@ def draft_forward(self, forward_batch: ForwardBatch): forward_batch, skip_attn_backend_init=True ).logits_output maybe_detect_nan(logits_output.next_token_logits, f"draft_forward step {i}") - if self.topk == 1: + if self.topk == 1 and not _is_hip: # topk=1 → degenerate single-path tree; `topk_p` is unused # downstream, so skip softmax and just argmax over logits. + # Gated to CUDA: on ROCm the argmax tie-break diverges from + # the softmax+max path on FP8 logits and corrupts MTP draft + # selection (DSV3.2 MTP GSM8K, see #26358). 
topk_index = torch.argmax( logits_output.next_token_logits, dim=-1, keepdim=True ) @@ -659,7 +662,9 @@ def _draft_extend_for_decode( draft_logits_output.hidden_states = draft_logits_output.hidden_states[ select_index ] - if self.topk == 1: + if self.topk == 1 and not _is_hip: + # Gated to CUDA: see #26358 — ROCm's argmax tie-break corrupts + # MTP draft selection on FP8 logits. ret_topk_index = torch.argmax( draft_logits_output.next_token_logits, dim=-1, keepdim=True )