From 0b9af3b4b9839adf8665bf3f21bc3a203be51cfc Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Sun, 24 May 2026 22:14:31 +0000 Subject: [PATCH] [Spec V2] Skip full-vocab softmax in EAGLE draft when topk == 1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The greedy spec-decoding path (`--speculative-eagle-topk 1`) currently runs a full-vocab `torch.softmax` + `torch.max` for every draft step and again for `_draft_extend_for_decode` (both eager and inside the captured draft-extend cuda graph). With topk == 1 the draft tree is a single path, so `topk_p` does not feed back into any ranking decision (see `spec_utils._select_top_k_tokens_later` — the scores are multiplied along a degenerate single branch). `topk_index = argmax(logits)` is identical to `argmax(softmax(logits))`, so the softmax is purely wasted work. Profile (Kimi-K2.5-NVFP4 / TP=4 / 80K ctx / EAGLE3 3-step / bs=1): `cunn_SoftMaxForward` was ~43 µs/call. It fired 2× per DRAFT_DECODE (steps 0 and 1 of the loop), 1× per `_draft_extend_for_decode`, and 1× inside the captured DRAFT_EXTEND graph — ~175 µs/cycle total. After this change all three call sites use argmax with a constant `topk_p = ones`. Patched sites: - `eagle_worker_v2.py:draft_forward` — inner draft step loop, runs inside the captured DRAFT_DECODE cuda graph. - `eagle_worker_v2.py:_draft_extend_for_decode` — post-graph reorganization after the draft-extend replay. - `eagle_draft_extend_cuda_graph_runner.py:capture_one_batch_size` — the softmax+topk burned into the DRAFT_EXTEND cuda graph itself. All three are gated by `if self.topk == 1`; multi-path tree behavior (topk > 1) is unchanged. 
Measured on the canonical workload (10 prompts, max-concurrency=1, no `SGLANG_SIMULATE_ACC_LEN`), reported as metric: baseline → patched (delta): Mean TPOT: 2.41 ms → 2.36 ms (-0.05 ms); Med TPOT: 2.37 ms → 2.34 ms (-0.03 ms); 1000/Mean: 414.9 → 423.7 tok/s (+8.8 tok/s, +2.1%); 1000/Med: 421.9 → 427.4 tok/s (+5.5 tok/s, +1.3%); accept_length: 3.92 → 3.94 (unchanged, within noise). GPU-side `cunn_SoftMaxForward` count drops to 0 in both DRAFT_DECODE and DRAFT_EXTEND kernel breakdowns. --- .../eagle_draft_extend_cuda_graph_runner.py | 10 +++++++-- .../sglang/srt/speculative/eagle_worker_v2.py | 22 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 9df3742fcf28..e588422782b5 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -401,8 +401,14 @@ def run_once(): forward_batch.positions, forward_batch, ) - probs = torch.softmax(ret.next_token_logits, dim=-1) - ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + ret.topk_index = torch.argmax( + ret.next_token_logits, dim=-1, keepdim=True + ) + ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32) + else: + probs = torch.softmax(ret.next_token_logits, dim=-1) + ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) forward_batch.out_cache_loc = output_cache_loc_backup forward_batch.spec_info.hidden_states = hidden_states_backup diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index a3d14af9122b..56d41ebcc1de 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -479,8 +479,16 @@ def draft_forward(self, forward_batch: ForwardBatch): forward_batch, skip_attn_backend_init=True ).logits_output
maybe_detect_nan(logits_output.next_token_logits, f"draft_forward step {i}") - probs = torch.softmax(logits_output.next_token_logits, dim=-1) - topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + # topk=1 → degenerate single-path tree; `topk_p` is unused + # downstream, so skip softmax and just argmax over logits. + topk_index = torch.argmax( + logits_output.next_token_logits, dim=-1, keepdim=True + ) + topk_p = torch.ones_like(topk_index, dtype=torch.float32) + else: + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) maybe_detect_oob( topk_index, 0, @@ -647,8 +655,14 @@ def _draft_extend_for_decode( draft_logits_output.hidden_states = draft_logits_output.hidden_states[ select_index ] - probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) - ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) + if self.topk == 1: + ret_topk_index = torch.argmax( + draft_logits_output.next_token_logits, dim=-1, keepdim=True + ) + ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32) + else: + probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1) + ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_hidden_states = draft_logits_output.hidden_states # Construct the return values