diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 9df3742fcf28..e588422782b5 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -401,8 +401,14 @@ def run_once(): forward_batch.positions, forward_batch, ) - probs = torch.softmax(ret.next_token_logits, dim=-1) - ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + ret.topk_index = torch.argmax( + ret.next_token_logits, dim=-1, keepdim=True + ) + ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32) + else: + probs = torch.softmax(ret.next_token_logits, dim=-1) + ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) forward_batch.out_cache_loc = output_cache_loc_backup forward_batch.spec_info.hidden_states = hidden_states_backup diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index a3d14af9122b..56d41ebcc1de 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -479,8 +479,16 @@ def draft_forward(self, forward_batch: ForwardBatch): forward_batch, skip_attn_backend_init=True ).logits_output maybe_detect_nan(logits_output.next_token_logits, f"draft_forward step {i}") - probs = torch.softmax(logits_output.next_token_logits, dim=-1) - topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + # topk=1 → degenerate single-path tree; `topk_p` is unused + # downstream, so skip softmax and just argmax over logits. + topk_index = torch.argmax( + logits_output.next_token_logits, dim=-1, keepdim=True + ) + topk_p = torch.ones_like(topk_index, dtype=torch.float32) + else: + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) maybe_detect_oob( topk_index, 0, @@ -647,8 +655,14 @@ def _draft_extend_for_decode( draft_logits_output.hidden_states = draft_logits_output.hidden_states[ select_index ] - probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) - ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + ret_topk_index = torch.argmax( + draft_logits_output.next_token_logits, dim=-1, keepdim=True + ) + ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32) + else: + probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) + ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_hidden_states = draft_logits_output.hidden_states # Construct the return values