From 7b671d9af31ef3ccc7d2b12a693287b5a899fcc6 Mon Sep 17 00:00:00 2001 From: "xuyongfei.xyf" Date: Mon, 22 Dec 2025 11:27:05 +0800 Subject: [PATCH 1/4] opt ds32 decode with mtp --- .../srt/layers/attention/nsa_backend.py | 165 +++++++++--------- .../srt/model_executor/cuda_graph_runner.py | 17 +- .../sglang/srt/speculative/eagle_worker_v2.py | 6 +- 3 files changed, 94 insertions(+), 94 deletions(-) diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index eb7d8cd0a8c1..bb5f40c1bf65 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -314,6 +314,19 @@ def __init__( model_runner.server_args.speculative_num_draft_tokens ) self.speculative_step_id = speculative_step_id + if self.speculative_num_draft_tokens is not None: + self.spec_tokens_range = torch.arange( + 1, + self.speculative_num_draft_tokens + 1, + dtype=torch.int32, + device=self.device, + ) + self.spec_tokens_neg_range = torch.arange( + -self.speculative_num_draft_tokens + 1, + 1, + dtype=torch.int32, + device=self.device, + ) self.device_capability = torch.cuda.get_device_capability() self.device_sm_major = self.device_capability[0] @@ -439,21 +452,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): dtype=torch.int32, device=device, ) - seqlens_expanded = torch.cat( - [ - torch.arange( - kv_len - qo_len + 1, - kv_len + 1, - dtype=torch.int32, - device=device, - ) - for qo_len, kv_len in zip( - forward_batch.extend_seq_lens_cpu, - forward_batch.seq_lens_cpu.tolist(), - strict=True, - ) - ] - ) if forward_batch.forward_mode.is_draft_extend_v2(): # DRAFT_EXTEND_V2: V2 worker pre-fills draft KV cache with ALL speculated # tokens upfront. 
All requests extend by the same fixed @@ -461,6 +459,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table = torch.repeat_interleave( page_table, repeats=self.speculative_num_draft_tokens, dim=0 ) + seqlens_expanded = ( + forward_batch.seq_lens.to(torch.int32).view(-1, 1) + + self.spec_tokens_neg_range + ).view(-1) else: # DRAFT_EXTEND (v1): V1 worker extends by (accept_length + 1) per request # after verification. Lengths vary per request based on how many tokens @@ -468,6 +470,21 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table = torch.repeat_interleave( page_table, repeats=forward_batch.extend_seq_lens, dim=0 ) + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=device, + ) + for qo_len, kv_len in zip( + forward_batch.extend_seq_lens_cpu, + forward_batch.seq_lens_cpu.tolist(), + strict=True, + ) + ] + ) elif forward_batch.forward_mode.is_extend(): assert ( forward_batch.extend_seq_lens_cpu is not None @@ -824,28 +841,9 @@ def init_forward_metadata_capture_cuda_graph( dtype=torch.int32, device=self.device, ) - - extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs - - seqlens_int32_cpu = [ - self.speculative_num_draft_tokens + kv_len - for kv_len in seq_lens.tolist() - ] - seqlens_expanded = torch.cat( - [ - torch.arange( - kv_len - qo_len + 1, - kv_len + 1, - dtype=torch.int32, - device=self.device, - ) - for qo_len, kv_len in zip( - extend_seq_lens_cpu, - seqlens_int32_cpu, - strict=True, - ) - ] - ) + seqlens_expanded = ( + seq_lens.to(torch.int32).view(-1, 1) + self.spec_tokens_range + ).view(-1) nsa_cache_seqlens_int32 = compute_nsa_seqlens( seqlens_expanded, nsa_index_topk=self.nsa_index_topk ) @@ -968,27 +966,7 @@ def init_forward_metadata_replay_cuda_graph( page_indices, repeats=self.speculative_num_draft_tokens, dim=0 ) metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices) - extend_seq_lens_cpu = 
[self.speculative_num_draft_tokens] * bs - - seqlens_int32_cpu = [ - self.speculative_num_draft_tokens + kv_len - for kv_len in seq_lens_cpu.tolist() - ] - seqlens_expanded = torch.cat( - [ - torch.arange( - kv_len - qo_len + 1, - kv_len + 1, - dtype=torch.int32, - device=self.device, - ) - for qo_len, kv_len in zip( - extend_seq_lens_cpu, - seqlens_int32_cpu, - strict=True, - ) - ] - ) + seqlens_expanded = (seq_lens.view(-1, 1) + self.spec_tokens_range).view(-1) metadata.nsa_seqlens_expanded.copy_(seqlens_expanded) nsa_cache_seqlens = compute_nsa_seqlens( seqlens_expanded, self.nsa_index_topk @@ -1001,33 +979,39 @@ def init_forward_metadata_replay_cuda_graph( metadata.cu_seqlens_k[1:].copy_( torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32) ) - - extend_seq_lens = spec_info.accept_length[:bs] - extend_seq_lens_cpu = extend_seq_lens.tolist() - page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k] - page_indices = torch.repeat_interleave( - page_indices, repeats=extend_seq_lens, dim=0 - ) + if forward_mode.is_draft_extend_v2(): + page_indices = torch.repeat_interleave( + page_indices, repeats=self.speculative_num_draft_tokens, dim=0 + ) + seqlens_expanded = ( + seq_lens.to(torch.int32).view(-1, 1) + self.spec_tokens_neg_range + ).view(-1) + else: + extend_seq_lens = spec_info.accept_length[:bs] + extend_seq_lens_cpu = extend_seq_lens.tolist() + + page_indices = torch.repeat_interleave( + page_indices, repeats=extend_seq_lens, dim=0 + ) + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=self.device, + ) + for qo_len, kv_len in zip( + extend_seq_lens_cpu, + seq_lens_cpu.tolist(), + strict=True, + ) + ] + ) metadata.page_table_1[: page_indices.shape[0], :max_seqlen_k].copy_( page_indices ) - - seqlens_expanded = torch.cat( - [ - torch.arange( - kv_len - qo_len + 1, - kv_len + 1, - dtype=torch.int32, - device=self.device, - ) - for qo_len, kv_len in zip( - extend_seq_lens_cpu, - 
seq_lens_cpu.tolist(), - strict=True, - ) - ] - ) metadata.nsa_seqlens_expanded[: seqlens_expanded.shape[0]].copy_( seqlens_expanded ) @@ -1263,7 +1247,16 @@ def forward_extend( page_size=1, ) - if self.nsa_prefill_impl == "tilelang": + nsa_impl = ( + self.nsa_decode_impl + if ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ) + else self.nsa_prefill_impl + ) + + if nsa_impl == "tilelang": if q_rope is not None: q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_tilelang( @@ -1273,7 +1266,7 @@ def forward_extend( sm_scale=layer.scaling, v_head_dim=layer.v_head_dim, ) - elif self.nsa_prefill_impl == "flashmla_sparse": + elif nsa_impl == "flashmla_sparse": if q_rope is not None: q_all = _concat_mla_absorb_q_general(q_nope, q_rope) @@ -1297,7 +1290,7 @@ def forward_extend( sm_scale=layer.scaling, v_head_dim=layer.v_head_dim, ) - elif self.nsa_prefill_impl == "flashmla_kv": + elif nsa_impl == "flashmla_kv": if q_rope is not None: q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_flashmla_kv( @@ -1310,7 +1303,7 @@ def forward_extend( metadata=metadata, page_table_1=page_table_1, ) - elif self.nsa_prefill_impl == "fa3": + elif nsa_impl == "fa3": return self._forward_fa3( q_rope=q_rope, kv_cache=kv_cache, @@ -1326,7 +1319,7 @@ def forward_extend( page_size=1, ) else: - raise ValueError(f"Unsupported {self.nsa_prefill_impl = }") + raise ValueError(f"Unsupported {nsa_impl = }") def forward_decode( self, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 6be1b11c68a8..607af2a77761 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -186,7 +186,7 @@ def set_torch_compile_config(): monkey_patch_torch_compile() -def get_batch_sizes_to_capture(model_runner: ModelRunner): +def get_batch_sizes_to_capture(model_runner: 
ModelRunner, num_tokens_per_bs=1): server_args = model_runner.server_args capture_bs = server_args.cuda_graph_bs @@ -203,7 +203,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): if require_gathered_buffer(server_args): mul_base *= get_attention_tp_size() - capture_bs = [bs for bs in capture_bs if bs % mul_base == 0] + capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0] capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] capture_bs = list(sorted(set(capture_bs))) @@ -267,11 +267,6 @@ def __init__(self, model_runner: ModelRunner): self.dllm_config = DllmConfig.from_server_args(model_runner.server_args) self.is_dllm = self.dllm_config is not None - # Batch sizes to capture - self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) - log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}") - if KTRANSFORMERS_AVAILABLE: - KTMoEWrapper.set_capture_batch_sizes(self.capture_bs) self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 @@ -291,6 +286,14 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size + # Batch sizes to capture + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture( + model_runner, self.num_tokens_per_bs + ) + log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}") + if KTRANSFORMERS_AVAILABLE: + KTMoEWrapper.set_capture_batch_sizes(self.capture_bs) + # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup if model_runner.server_args.enable_return_hidden_states: self.capture_hidden_mode = CaptureHiddenMode.FULL diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index d3dcc2afd5a5..663a03bf8a74 100644 --- 
a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,6 +12,7 @@ from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import ( EAGLEDraftNpuGraphRunner, ) +from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnMultiStepBackend from sglang.srt.layers.attention.triton_backend import TritonMultiStepDraftBackend from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.moe.utils import ( @@ -275,7 +276,10 @@ def init_cuda_graphs(self): _is_npu or ( _is_cuda - and isinstance(self.draft_attn_backend, TritonMultiStepDraftBackend) + and isinstance( + self.draft_attn_backend, + (TritonMultiStepDraftBackend, NativeSparseAttnMultiStepBackend), + ) ) ): tic = time.perf_counter() From 10e521eef84859b5aaa41a1e56a75fe9beb0d02d Mon Sep 17 00:00:00 2001 From: "xuyongfei.xyf" Date: Mon, 12 Jan 2026 21:08:20 +0800 Subject: [PATCH 2/4] revert specv2 opt --- .../srt/layers/attention/nsa_backend.py | 146 ++++++++++-------- .../sglang/srt/speculative/eagle_worker_v2.py | 6 +- 2 files changed, 82 insertions(+), 70 deletions(-) diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index bb5f40c1bf65..55ca89dd194f 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -314,19 +314,6 @@ def __init__( model_runner.server_args.speculative_num_draft_tokens ) self.speculative_step_id = speculative_step_id - if self.speculative_num_draft_tokens is not None: - self.spec_tokens_range = torch.arange( - 1, - self.speculative_num_draft_tokens + 1, - dtype=torch.int32, - device=self.device, - ) - self.spec_tokens_neg_range = torch.arange( - -self.speculative_num_draft_tokens + 1, - 1, - dtype=torch.int32, - device=self.device, - ) self.device_capability = torch.cuda.get_device_capability() self.device_sm_major = self.device_capability[0] @@ -452,6 
+439,21 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): dtype=torch.int32, device=device, ) + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=device, + ) + for qo_len, kv_len in zip( + forward_batch.extend_seq_lens_cpu, + forward_batch.seq_lens_cpu.tolist(), + strict=True, + ) + ] + ) if forward_batch.forward_mode.is_draft_extend_v2(): # DRAFT_EXTEND_V2: V2 worker pre-fills draft KV cache with ALL speculated # tokens upfront. All requests extend by the same fixed @@ -459,10 +461,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table = torch.repeat_interleave( page_table, repeats=self.speculative_num_draft_tokens, dim=0 ) - seqlens_expanded = ( - forward_batch.seq_lens.to(torch.int32).view(-1, 1) - + self.spec_tokens_neg_range - ).view(-1) else: # DRAFT_EXTEND (v1): V1 worker extends by (accept_length + 1) per request # after verification. Lengths vary per request based on how many tokens @@ -470,21 +468,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table = torch.repeat_interleave( page_table, repeats=forward_batch.extend_seq_lens, dim=0 ) - seqlens_expanded = torch.cat( - [ - torch.arange( - kv_len - qo_len + 1, - kv_len + 1, - dtype=torch.int32, - device=device, - ) - for qo_len, kv_len in zip( - forward_batch.extend_seq_lens_cpu, - forward_batch.seq_lens_cpu.tolist(), - strict=True, - ) - ] - ) elif forward_batch.forward_mode.is_extend(): assert ( forward_batch.extend_seq_lens_cpu is not None @@ -841,9 +824,28 @@ def init_forward_metadata_capture_cuda_graph( dtype=torch.int32, device=self.device, ) - seqlens_expanded = ( - seq_lens.to(torch.int32).view(-1, 1) + self.spec_tokens_range - ).view(-1) + + extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs + + seqlens_int32_cpu = [ + self.speculative_num_draft_tokens + kv_len + for kv_len in seq_lens.tolist() + ] + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - 
qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=self.device, + ) + for qo_len, kv_len in zip( + extend_seq_lens_cpu, + seqlens_int32_cpu, + strict=True, + ) + ] + ) nsa_cache_seqlens_int32 = compute_nsa_seqlens( seqlens_expanded, nsa_index_topk=self.nsa_index_topk ) @@ -966,7 +968,27 @@ def init_forward_metadata_replay_cuda_graph( page_indices, repeats=self.speculative_num_draft_tokens, dim=0 ) metadata.page_table_1[:, :max_seqlen_k].copy_(page_indices) - seqlens_expanded = (seq_lens.view(-1, 1) + self.spec_tokens_range).view(-1) + extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * bs + + seqlens_int32_cpu = [ + self.speculative_num_draft_tokens + kv_len + for kv_len in seq_lens_cpu.tolist() + ] + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=self.device, + ) + for qo_len, kv_len in zip( + extend_seq_lens_cpu, + seqlens_int32_cpu, + strict=True, + ) + ] + ) metadata.nsa_seqlens_expanded.copy_(seqlens_expanded) nsa_cache_seqlens = compute_nsa_seqlens( seqlens_expanded, self.nsa_index_topk @@ -979,39 +1001,33 @@ def init_forward_metadata_replay_cuda_graph( metadata.cu_seqlens_k[1:].copy_( torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32) ) - page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k] - if forward_mode.is_draft_extend_v2(): - page_indices = torch.repeat_interleave( - page_indices, repeats=self.speculative_num_draft_tokens, dim=0 - ) - seqlens_expanded = ( - seq_lens.to(torch.int32).view(-1, 1) + self.spec_tokens_neg_range - ).view(-1) - else: - extend_seq_lens = spec_info.accept_length[:bs] - extend_seq_lens_cpu = extend_seq_lens.tolist() - page_indices = torch.repeat_interleave( - page_indices, repeats=extend_seq_lens, dim=0 - ) - seqlens_expanded = torch.cat( - [ - torch.arange( - kv_len - qo_len + 1, - kv_len + 1, - dtype=torch.int32, - device=self.device, - ) - for qo_len, kv_len in zip( - extend_seq_lens_cpu, - seq_lens_cpu.tolist(), - strict=True, - 
) - ] - ) + extend_seq_lens = spec_info.accept_length[:bs] + extend_seq_lens_cpu = extend_seq_lens.tolist() + + page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k] + page_indices = torch.repeat_interleave( + page_indices, repeats=extend_seq_lens, dim=0 + ) metadata.page_table_1[: page_indices.shape[0], :max_seqlen_k].copy_( page_indices ) + + seqlens_expanded = torch.cat( + [ + torch.arange( + kv_len - qo_len + 1, + kv_len + 1, + dtype=torch.int32, + device=self.device, + ) + for qo_len, kv_len in zip( + extend_seq_lens_cpu, + seq_lens_cpu.tolist(), + strict=True, + ) + ] + ) metadata.nsa_seqlens_expanded[: seqlens_expanded.shape[0]].copy_( seqlens_expanded ) diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 663a03bf8a74..d3dcc2afd5a5 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,7 +12,6 @@ from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import ( EAGLEDraftNpuGraphRunner, ) -from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnMultiStepBackend from sglang.srt.layers.attention.triton_backend import TritonMultiStepDraftBackend from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.moe.utils import ( @@ -276,10 +275,7 @@ def init_cuda_graphs(self): _is_npu or ( _is_cuda - and isinstance( - self.draft_attn_backend, - (TritonMultiStepDraftBackend, NativeSparseAttnMultiStepBackend), - ) + and isinstance(self.draft_attn_backend, TritonMultiStepDraftBackend) ) ): tic = time.perf_counter() From 0eef2ed5b4a479228719b0c509dfe535971226a3 Mon Sep 17 00:00:00 2001 From: "xuyongfei.xyf" Date: Mon, 12 Jan 2026 21:31:38 +0800 Subject: [PATCH 3/4] include v2 --- python/sglang/srt/layers/attention/nsa_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/nsa_backend.py 
b/python/sglang/srt/layers/attention/nsa_backend.py index 55ca89dd194f..8d0b00a136d8 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -1267,7 +1267,7 @@ def forward_extend( self.nsa_decode_impl if ( forward_batch.forward_mode.is_target_verify() - or forward_batch.forward_mode.is_draft_extend() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) ) else self.nsa_prefill_impl ) From 7adb04612c2d8043aff31900192a421f83228fb8 Mon Sep 17 00:00:00 2001 From: "xuyongfei.xyf" Date: Fri, 16 Jan 2026 22:29:19 +0800 Subject: [PATCH 4/4] add comment --- python/sglang/srt/model_executor/cuda_graph_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 607af2a77761..55a23e333d19 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -199,10 +199,12 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner, num_tokens_per_bs=1): if server_args.enable_two_batch_overlap: mul_base *= 2 + num_tokens_per_bs = 1 # TBO is not tested with num_tokens_per_bs > 1, so force num_tokens_per_bs to 1 if require_gathered_buffer(server_args): mul_base *= get_attention_tp_size() + # Model input token count = bs * num_tokens_per_bs; must be a multiple of mul_base (TBO factor and attention TP size). capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0] capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]