diff --git a/docs/basic_usage/deepseek_v32.md b/docs/basic_usage/deepseek_v32.md index 6035f8911a63..b94e8472a3e8 100644 --- a/docs/basic_usage/deepseek_v32.md +++ b/docs/basic_usage/deepseek_v32.md @@ -71,6 +71,10 @@ python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. - The default value of `--max-running-requests` is set to `48` for MTP. For larger batch sizes, this value should be increased beyond the default value. +```{tip} +To enable the experimental overlap scheduler for EAGLE speculative decoding, set the environment variable `SGLANG_ENABLE_SPEC_V2=1`. This can improve performance by enabling overlap scheduling between draft and verification stages. +``` + ## Function Calling and Reasoning Parser The usage of function calling and reasoning parser is the same as DeepSeek V3.1. Please refer to [Reasoning Parser](https://docs.sglang.io/advanced_features/separate_reasoning.html) and [Tool Parser](https://docs.sglang.io/advanced_features/tool_parser.html) documents. 
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 2b2ae7e52e9e..85be163e3e5b 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -295,7 +295,7 @@ def _get_topk_paged( blocksize = page_size if ( forward_batch.forward_mode.is_target_verify() - or forward_batch.forward_mode.is_draft_extend() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) ): seqlens_32 = metadata.get_seqlens_expanded() else: @@ -900,7 +900,7 @@ def forward_cuda( if ( forward_batch.forward_mode.is_decode_or_idle() or forward_batch.forward_mode.is_target_verify() - or forward_batch.forward_mode.is_draft_extend() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) ): topk_result = self._get_topk_paged( forward_batch, layer_id, q_fp8, weights, metadata diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index ed7f241789b6..7ea4322c2fe2 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -389,7 +389,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table = torch.repeat_interleave( page_table, repeats=self.speculative_num_draft_tokens, dim=0 ) - elif forward_batch.forward_mode.is_draft_extend(): + elif forward_batch.forward_mode.is_draft_extend(include_v2=True): assert ( forward_batch.extend_seq_lens_cpu is not None and forward_batch.extend_seq_lens is not None @@ -422,9 +422,20 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) ] ) - page_table = torch.repeat_interleave( - page_table, repeats=forward_batch.extend_seq_lens, dim=0 - ) + if forward_batch.forward_mode.is_draft_extend_v2(): + # DRAFT_EXTEND_V2: V2 worker pre-fills draft KV cache with ALL speculated + # tokens upfront. All requests extend by the same fixed + # (speculative_num_draft_tokens). 
Use a scalar repeat count to avoid a GPU sync. + page_table = torch.repeat_interleave( + page_table, repeats=self.speculative_num_draft_tokens, dim=0 + ) + else: + # DRAFT_EXTEND (v1): V1 worker extends by (accept_length + 1) per request + # after verification. Lengths vary per request based on how many tokens + # were accepted. + page_table = torch.repeat_interleave( + page_table, repeats=extend_seq_lens_cpu, dim=0 + ) elif forward_batch.forward_mode.is_extend(): assert ( @@ -632,7 +643,9 @@ def init_forward_metadata_capture_cuda_graph( ) else: flashmla_metadata = None - elif forward_mode.is_target_verify() or forward_mode.is_draft_extend(): + elif forward_mode.is_target_verify() or forward_mode.is_draft_extend( + include_v2=True + ): cache_seqlens_int32 = (seq_lens + self.speculative_num_draft_tokens).to( torch.int32 ) @@ -796,7 +809,7 @@ def init_forward_metadata_replay_cuda_graph( seqlens_expanded, self.nsa_index_topk ) metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens) - elif forward_mode.is_draft_extend(): + elif forward_mode.is_draft_extend(include_v2=True): max_seqlen_k = int(seq_lens_cpu.max().item()) cache_seqlens = seq_lens.to(torch.int32) metadata.cache_seqlens_int32.copy_(cache_seqlens)