diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index ad17631bcc88..bf4c865e09eb 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -30,12 +30,15 @@ from sglang.srt.speculative.eagle_info import EagleDraftExtendInput from sglang.srt.speculative.spec_utils import fast_topk from sglang.srt.utils import ( + is_hip, require_attn_tp_gather, require_gathered_buffer, require_mlp_sync, require_mlp_tp_gather, ) +_is_hip = is_hip() + if TYPE_CHECKING: from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -401,8 +404,18 @@ def run_once(): forward_batch.positions, forward_batch, ) - probs = torch.softmax(ret.next_token_logits, dim=-1) - ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1 and not _is_hip: + # topk=1 → degenerate single-path tree; skip full-vocab softmax + # and use argmax(logits) directly. Gated off on ROCm/HIP because + # the MTP draft path is sensitive to whether topk_p is the true + # probability or a placeholder; see #26358 (revert) / R108. + ret.topk_index = torch.argmax( + ret.next_token_logits, dim=-1, keepdim=True + ) + ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32) + else: + probs = torch.softmax(ret.next_token_logits, dim=-1) + ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) forward_batch.out_cache_loc = output_cache_loc_backup forward_batch.spec_info.hidden_states = hidden_states_backup diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index eb5c40a7791b..8c6d3bd8a53d 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -483,8 +483,18 @@ def draft_forward(self, forward_batch: ForwardBatch): forward_batch, skip_attn_backend_init=True ).logits_output maybe_detect_nan(logits_output.next_token_logits, f"draft_forward step {i}") - probs = torch.softmax(logits_output.next_token_logits, dim=-1) - topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1 and not _is_hip: + # topk=1 → degenerate single-path tree; skip full-vocab softmax + # and use argmax(logits) directly. Gated off on ROCm/HIP because + # the MTP draft path is sensitive to whether topk_p is the true + # probability or a placeholder; see #26358 (revert) / R108. + topk_index = torch.argmax( + logits_output.next_token_logits, dim=-1, keepdim=True + ) + topk_p = torch.ones_like(topk_index, dtype=torch.float32) + else: + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) maybe_detect_oob( topk_index, 0, @@ -651,8 +661,18 @@ def _draft_extend_for_decode( draft_logits_output.hidden_states = draft_logits_output.hidden_states[ select_index ] - probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) - ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1 and not _is_hip: + # topk=1 → degenerate single-path tree; skip full-vocab softmax + # and use argmax(logits) directly. Gated off on ROCm/HIP because + # the MTP draft path is sensitive to whether topk_p is the true + # probability or a placeholder; see #26358 (revert) / R108. + ret_topk_index = torch.argmax( + draft_logits_output.next_token_logits, dim=-1, keepdim=True + ) + ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32) + else: + probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) + ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_hidden_states = draft_logits_output.hidden_states # Construct the return values