diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 449ddaea10dc..e554c028686f 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -404,6 +404,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--enable-layerwise-nvtx-marker` | Enable layerwise NVTX profiling annotations for the model. This adds NVTX markers to every layer for detailed per-layer performance analysis with Nsight Systems. | `False` | bool flag (set to enable) | | `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) | | `--enable-nsa-prefill-context-parallel` | Context parallelism used in the long sequence prefill phase of DeepSeek v3.2 | `False` | bool flag (set to enable) | +| `--nsa-prefill-cp-mode` | Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: `in-seq-split` (default), `round-robin-split`. `round-robin-split` distributes tokens across ranks based on `token_idx % cp_size`. It supports multi-batch prefill, fused MoE, and FP8 KV cache. | `in-seq-split` | Type: str | ## Forward hooks | Argument | Description | Defaults | Options | diff --git a/docs/basic_usage/deepseek_v32.md b/docs/basic_usage/deepseek_v32.md index 955bfaed0d3d..fd81f253509c 100644 --- a/docs/basic_usage/deepseek_v32.md +++ b/docs/basic_usage/deepseek_v32.md @@ -290,3 +290,15 @@ Some features are still not supported at present. - **Other Args**: Currently only supports moe_dense_tp_size=1, kv_cache_dtype = "bf16", moe_a2a_backend = "deepep", - **DP_size**: `CP_size` reuses `atten_tp_size`, which is equal to `TP_size` / `DP_size`. For the cp function to work correctly, `TP_size` must be divisible by `DP_size`, and TP_size / DP_size > 1 (to ensure CP_size > 1). - **Detailed design reference**: https://github.com/sgl-project/sglang/pull/12065 + +### Alternative context parallel mode + +You can switch the CP token splitting mode for prefill by specifying the parameter `--nsa-prefill-cp-mode round-robin-split`. It distributes tokens across ranks based on `token_idx % cp_size`. +In this scenario, compared with the aforementioned method, it additionally supports the fused MoE backend (the fused MoE backend may deliver better performance than DeepEP in single-machine scenarios), +FP8 KV-cache, and multi-batch prefill inference. For more details, please refer to PR https://github.com/sgl-project/sglang/pull/13959. 
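+
+The sketch below only illustrates the splitting rule (a standalone toy example, not SGLang's internal implementation): rank `r` keeps every `cp_size`-th token starting at offset `r`.
+
+```python
+def round_robin_split(num_tokens: int, cp_size: int) -> dict:
+    # Rank r keeps tokens r, r + cp_size, r + 2 * cp_size, ...
+    return {r: list(range(r, num_tokens, cp_size)) for r in range(cp_size)}
+
+# With cp_size = 4 and 8 prefill tokens:
+# {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}
+print(round_robin_split(8, 4))
+```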
+ +Example usage: +```bash +# Launch with FusedMoe + CP8 + DP1 +python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 1 --enable-dp-attention --enable-nsa-prefill-context-parallel --nsa-prefill-cp-mode round-robin-split --max-running-requests 32 +``` diff --git a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py index 6626c468f58f..dc811fcca70f 100644 --- a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py +++ b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py @@ -10,7 +10,7 @@ ) from sglang.srt.layers.attention.nsa.utils import ( cp_split_and_rebuild_position, - enable_prefill_cp, + nsa_use_prefill_cp, ) from sglang.srt.layers.communicator import get_attn_tp_context @@ -192,12 +192,12 @@ def forward_mla_prepare_npu( q_nope_out = q_nope_out.transpose(0, 1) - if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): positions = cp_split_and_rebuild_position(forward_batch, positions) q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe) - if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): # support allgather+rerrange k_nope, k_pe = m.rebuild_cp_kv_cache( latent_cache, forward_batch, k_nope, k_pe @@ -338,12 +338,12 @@ def forward_dsa_prepare_npu( q_nope_out = q_nope_out.transpose(0, 1) - if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): positions = cp_split_and_rebuild_position(forward_batch, positions) q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe) - if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp): # support allgather+rerrange k_nope, k_pe = m.rebuild_cp_kv_cache( latent_cache, forward_batch, k_nope, k_pe diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 51b9d4ff32c0..c9e82e4b12b9 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -28,6 +28,7 @@ NSA_DUAL_STREAM, cp_all_gather_rerange_output, is_nsa_enable_prefill_cp, + is_nsa_prefill_cp_in_seq_split, ) from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.linear import ReplicatedLinear @@ -63,6 +64,21 @@ def get_seqlens_expanded(self) -> torch.Tensor: Return: (sum_extend_seq_len,) int32 tensor """ + def get_indexer_kvcache_range(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Return: (tokens, ), (tokens, ) int32, k_start and k_end in kv cache(token,xxx) for each token. + """ + + def get_indexer_seq_len_cpu(self) -> torch.Tensor: + """ + Return: seq lens for each batch. + """ + + def get_token_to_batch_idx(self) -> torch.Tensor: + """ + Return: batch idx for each token. 
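+        e.g. extend_seq_lens = [3, 2] -> [0, 0, 0, 1, 1] (one batch index per extended token, no CP split).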
+ """ + @abstractmethod def topk_transform( self, @@ -227,15 +243,6 @@ def _get_q_k_bf16( query[..., : self.rope_head_dim] = q_rope key[..., : self.rope_head_dim] = k_rope - # allgather+rerrange - if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp: - key = cp_all_gather_rerange_output( - key.contiguous(), - self.cp_size, - forward_batch, - torch.cuda.current_stream(), - ) - if enable_dual_stream: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) @@ -248,6 +255,14 @@ def _get_q_k_bf16( query = rotate_activation(query) key = rotate_activation(key) + # allgather+rerrange + if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp: + key = cp_all_gather_rerange_output( + key.contiguous(), + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) return query, key def _get_k_bf16( @@ -373,15 +388,7 @@ def _get_topk_ragged( weights = weights.squeeze(-1) k_fp8_list = [] k_scale_list = [] - ks_list = [] - ke_list = [] - # Token-to-batch mapping for PAGED chunk alignment - token_to_batch_idx: List[int] = [] - - q_offset = 0 - k_offset = 0 - seq_lens_expanded = metadata.get_seqlens_expanded() block_tables = metadata.get_page_table_64() assert ( @@ -389,8 +396,19 @@ def _get_topk_ragged( and forward_batch.extend_seq_lens_cpu is not None ) - for i in range(forward_batch.batch_size): - seq_len = forward_batch.seq_lens_cpu[i].item() + batch_size = len(block_tables) + token_nums, _, _ = q_fp8.shape + device = q_fp8.device + topk_result = torch.full( + (token_nums, self.index_topk), -1, device=device, dtype=torch.int32 + ) + if batch_size == 0: + return topk_result + + indexer_seq_lens_cpu = metadata.get_indexer_seq_len_cpu() + assert len(indexer_seq_lens_cpu) == batch_size + for i in range(batch_size): + seq_len = indexer_seq_lens_cpu[i].item() assert isinstance(seq_len, int) # Use fused Triton kernel to get both K and scale in a single call k_fp8, k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_buffer( @@ -398,47 +416,23 @@ def _get_topk_ragged( seq_len, block_tables[i], ) - extend_seq_len = forward_batch.extend_seq_lens_cpu[i] - ks = torch.full( - (extend_seq_len,), k_offset, dtype=torch.int32, device="cuda" - ) - ke = ks + seq_lens_expanded[q_offset : q_offset + extend_seq_len] k_fp8_list.append(k_fp8) k_scale_list.append(k_scale) - ks_list.append(ks) - ke_list.append(ke) - - token_to_batch_idx.extend([i] * extend_seq_len) - q_offset += extend_seq_len - k_offset += seq_len k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) kv_fp8 = (k_fp8, k_scale) - ks = torch.cat(ks_list, dim=0) - ke = torch.cat(ke_list, dim=0) - - # Suppose there are two requests, with extend_seq_len = [3, 2] - # and seq_lens = [10, 4] - # The logits matrix looks like this, with * representing the valid logits - # and - representing the invalid logits: - # - # ********--|---- - # *********-|---- - # **********|---- - # ----------|***- - # ----------|**** - # - # ks = [0, 0, 0, 10, 10] - # ke = [8, 9, 10, 13, 14] - - token_nums, _, _ = q_fp8.shape - device = q_fp8.device + ks, ke = metadata.get_indexer_kvcache_range() + seq_lens_expanded = metadata.get_seqlens_expanded() + token_to_batch_idx = metadata.get_token_to_batch_idx() + q_offset = ks.shape[0] + k_offset = k_fp8.shape[0] # Check if we need to chunk to avoid OOM need_chunk, free_mem = self._should_chunk_mqa_logits(q_offset, k_offset, device) if not need_chunk: + assert 
q_fp8[:q_offset].shape[0] != 0 logits = deep_gemm.fp8_mqa_logits( q_fp8[:q_offset], kv_fp8, @@ -451,12 +445,6 @@ def _get_topk_ragged( assert logits.shape[1] == k_offset raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks) - topk_result = torch.full( - (token_nums, self.index_topk), - -1, - device=device, - dtype=torch.int32, - ) topk_result[:q_offset] = raw_topk_result return topk_result @@ -477,17 +465,6 @@ def _get_topk_ragged( global_topk_offset.shape[0] >= q_offset ), f"topk_indices_offset too short: {global_topk_offset.shape[0]} < {q_offset}" - topk_result = torch.full( - (token_nums, self.index_topk), -1, device=device, dtype=torch.int32 - ) - - # Only materialize batch index tensor when PAGED path needs it - token_to_batch_idx_tensor = None - if global_topk_offset is None: - token_to_batch_idx_tensor = torch.tensor( - token_to_batch_idx, dtype=torch.long, device=device - ) - start = 0 while start < q_offset: end = min(start + max_rows, q_offset) @@ -516,7 +493,7 @@ def _get_topk_ragged( cu_seqlens_q_chunk = torch.ones( B_chunk, dtype=torch.int32, device=device ) - batch_idx_chunk = token_to_batch_idx_tensor[start:end] + batch_idx_chunk = token_to_batch_idx[start:end] raw_topk_chunk = metadata.topk_transform( logits_chunk, @@ -911,7 +888,7 @@ def forward_cuda( else: if ( forward_batch.nsa_cp_metadata is not None - and self.nsa_enable_prefill_cp + and is_nsa_prefill_cp_in_seq_split() ): kv_len_prev = forward_batch.nsa_cp_metadata.kv_len_prev kv_len_next = forward_batch.nsa_cp_metadata.kv_len_next diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py index 9cbaa6c0ada0..94cb96e7c0e6 100644 --- a/python/sglang/srt/layers/attention/nsa/utils.py +++ b/python/sglang/srt/layers/attention/nsa/utils.py @@ -1,15 +1,26 @@ # temp NSA debugging environ from dataclasses import dataclass from itertools import accumulate -from typing import List +from typing import TYPE_CHECKING, List, Tuple, Union import torch import torch.nn.functional as F - -from sglang.srt.layers.dp_attention import get_attention_tp_group +import triton +import triton.language as tl + +from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather_into_tensor, + get_attention_tp_group, + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import get_bool_env_var +if TYPE_CHECKING: + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true") NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true") @@ -41,6 +52,75 @@ def is_nsa_enable_prefill_cp(): return get_global_server_args().enable_nsa_prefill_context_parallel +def is_nsa_prefill_cp_in_seq_split(): + return ( + is_nsa_enable_prefill_cp() + and get_global_server_args().nsa_prefill_cp_mode == "in-seq-split" + ) + + +def is_nsa_prefill_cp_round_robin_split(): + return ( + is_nsa_enable_prefill_cp() + and get_global_server_args().nsa_prefill_cp_mode == "round-robin-split" + ) + + +def can_nsa_prefill_cp_round_robin_split(forward_batch: "ForwardBatch"): + if not forward_batch.forward_mode.is_context_parallel_extend(): + return False + cp_size = get_attention_tp_size() + seq_len = sum(forward_batch.extend_seq_lens_cpu) + return is_nsa_prefill_cp_round_robin_split() and seq_len > 0 and cp_size > 1 + + +def nsa_cp_round_robin_split_data(input_: Union[torch.Tensor, List]): + """ + # for round-robin-split, split the tokens evenly 
according to the rule of token_idx % cp_size. + | +-----------before split------------+| + | token0, token1, token2, token3, token4, token5, token6, token7, ... + | + | +--------------result-------------------+ + | dp_atten_tp0: token0, token4, token8, token12, token16, ... | + | dp_atten_tp1: token1, token5, token9, token13, token17, ... | + | dp_atten_tp2: token2, token6, token10, token14, token18, ... | + | dp_atten_tp3: token3, token7, token11, token15, token19, ... | + | +-------------------------+ + """ + cp_size = get_attention_tp_size() + cp_rank = get_attention_tp_rank() + if isinstance(input_, (tuple, list)): + indices = range(cp_rank, len(input_), cp_size) + return input_[indices] + + tokens = len(input_) + if tokens % cp_size != 0: + cur_len = tokens // cp_size + (tokens % cp_size > cp_rank) + if cur_len == 0: + return input_.new_empty(0, *input_.shape[1:]) + indices = torch.arange(cp_rank, tokens, cp_size, device=input_.device) + return input_[indices] + + # for torch device tensor + return input_.view(-1, cp_size, *input_.shape[1:])[:, cp_rank].contiguous() + + +def pad_nsa_cache_seqlens(forward_batch: "ForwardBatch", nsa_cache_seqlens): + attn_tp_size = get_attention_tp_size() + if attn_tp_size == 1 or not can_nsa_prefill_cp_round_robin_split(forward_batch): + return nsa_cache_seqlens + tokens = sum(forward_batch.extend_seq_lens_cpu) + pad_len = (tokens - 1) // attn_tp_size + 1 - nsa_cache_seqlens.shape[0] + if pad_len > 0: + nsa_cache_seqlens = torch.cat( + [ + nsa_cache_seqlens, + nsa_cache_seqlens.new_zeros(pad_len, *nsa_cache_seqlens.shape[1:]), + ] + ) + return nsa_cache_seqlens + + @dataclass class NSAContextParallelMetadata: @@ -61,7 +141,17 @@ class NSAContextParallelMetadata: total_seq_lens: torch.Tensor = None -def can_cp_split(cur_cp_seq_len: int, cp_size: int, use_nsa: bool, forward_batch): +def can_cp_split(seq_len: int, cp_size: int, use_nsa: bool, forward_batch): + if is_nsa_prefill_cp_round_robin_split(): + cur_cp_seq_len = seq_len // cp_size + assert ( + seq_len % cp_size == 0 + ), f"seq_len {seq_len} is not divisible by cp_size {cp_size} when nsa_prefill_cp_mode is round-robin-split" + else: + # TODO current just support prefill batch=1 and len(input_ids) > self.cp_size * 2 + # Note: (self.cp_size * 2) To achieve load balancing for seq computation, + # the seq data needs to be divided and recombined at twice the size of cp_size. 
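+        # e.g. with cp_size = 4 the sequence is cut into 8 blocks; each rank is assigned one
+        # block from the front half and its mirror from the back half (block0+block7,
+        # block1+block6, ...), so causal-attention work stays roughly balanced across ranks.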
+ cur_cp_seq_len = seq_len // (cp_size * 2) if ( cur_cp_seq_len != 0 and cp_size > 1 @@ -75,6 +165,13 @@ def can_cp_split(cur_cp_seq_len: int, cp_size: int, use_nsa: bool, forward_batch def cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor): + if is_nsa_prefill_cp_round_robin_split(): + cp_size = get_attention_tp_size() + assert ( + input_.shape[0] % cp_size == 0 + ), f"Expect input shape 0 can divided by cp size, but got input shape {input_.shape}, cp size {cp_size}" + return nsa_cp_round_robin_split_data(input_) + input_list = list( torch.split(input_, forward_batch.nsa_cp_metadata.split_list, dim=0) ) @@ -85,6 +182,14 @@ def cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor): def cp_split_and_rebuild_position(forward_batch, positions: torch.Tensor): + if is_nsa_prefill_cp_round_robin_split(): + cp_size = get_attention_tp_size() + assert positions.shape[0] % cp_size == 0, ( + f"Expect positions shape 0 can divided by cp size, but got positions shape {positions.shape}, " + f"cp size {cp_size}" + ) + return nsa_cp_round_robin_split_data(positions) + position_id_list = list( torch.split(positions, forward_batch.nsa_cp_metadata.split_list, dim=-1) ) @@ -95,7 +200,75 @@ def cp_split_and_rebuild_position(forward_batch, positions: torch.Tensor): return positions -def enable_prefill_cp(forward_batch, nsa_enable_prefill_cp): +@triton.jit +def nsa_cp_round_robin_split_q_seqs_kernel( + in_seqs_ptr, + out_seqs_ptr, + bs_idx_ptr, + tokens: tl.constexpr, + cp_size: tl.constexpr, + cp_rank: tl.constexpr, +): + extra_seq = 0 + bs_idx = 0 + for bs in range(tokens): + cur_len = tl.load(in_seqs_ptr + bs) + cur_len += extra_seq + cur_seq = cur_len // cp_size + (cur_len % cp_size > cp_rank) + if cur_seq > 0: + tl.store(bs_idx_ptr + bs_idx, bs) + tl.store(out_seqs_ptr + bs_idx, cur_seq) + bs_idx += 1 + extra_seq = cur_len - cur_seq * cp_size + + +def nsa_cp_round_robin_split_q_seqs_cpu(extend_seqs): + cp_size = get_attention_tp_size() + cp_rank = get_attention_tp_rank() + extra_seq = 0 + q_seqs = [] + for bs, cur_len in enumerate(extend_seqs): + cur_len += extra_seq + cur_seq = cur_len // cp_size + int(cur_len % cp_size > cp_rank) + q_seqs.append(cur_seq) + extra_seq = cur_len - cur_seq * cp_size + bs_idx = list([i for i, x in enumerate(q_seqs) if x > 0]) + q_seqs = [q_len for q_len in q_seqs if q_len > 0] + return q_seqs, bs_idx + + +def nsa_cp_round_robin_split_q_seqs( + extend_seqs_cpu, extend_seqs +) -> Tuple[List, torch.Tensor, List, torch.Tensor]: + """ + round-robin-split distributes tokens across ranks based on token_idx % cp_size. + + Return: + ret_q_lens_cpu(List) and ret_q_lens(torch.Tensor): the partitioned length (excluding zeros) on the current cp rank + for each sequence after distribution across cp ranks. + bs_idx_cpu(List) and bs_idx(torch.Tensor): marks which sequences are ultimately selected, + i.e., those with a partitioned length greater than zero. 
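+    Example (cp_size = 4, extend_seqs = [5, 3], flattened tokens 0..7):
+        rank 0 owns tokens [0, 4] -> ret_q_lens_cpu = [2], bs_idx_cpu = [0]
+        rank 1 owns tokens [1, 5] -> ret_q_lens_cpu = [1, 1], bs_idx_cpu = [0, 1]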
+ """ + cp_size = get_attention_tp_size() + cp_rank = get_attention_tp_rank() + # len(ret_q_lens_cpu) == len(bs_idx_cpu) + ret_q_lens_cpu, bs_idx_cpu = nsa_cp_round_robin_split_q_seqs_cpu(extend_seqs_cpu) + ret_q_lens = torch.empty( + (len(bs_idx_cpu),), device=extend_seqs.device, dtype=extend_seqs.dtype + ) + bs_idx = torch.empty( + (len(bs_idx_cpu),), device=extend_seqs.device, dtype=torch.int32 + ) + grid = (1,) + nsa_cp_round_robin_split_q_seqs_kernel[grid]( + extend_seqs, ret_q_lens, bs_idx, len(extend_seqs), cp_size, cp_rank + ) + return ret_q_lens_cpu, ret_q_lens, bs_idx_cpu, bs_idx + + +def nsa_use_prefill_cp(forward_batch, nsa_enable_prefill_cp=None): + if nsa_enable_prefill_cp is None: + nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() if ( forward_batch.nsa_cp_metadata is not None and nsa_enable_prefill_cp @@ -149,6 +322,7 @@ def cp_attn_tp_all_gather_reorganazied_into_tensor( def cp_all_gather_rerange_output(input_tensor, cp_size, forward_batch, stream): """ + # for in-seq-split | +-----------before allgather------------+| | | dp_atten_tp0: block0, block7 | | | dp_atten_tp1: block1, block6 | @@ -161,7 +335,34 @@ def cp_all_gather_rerange_output(input_tensor, cp_size, forward_batch, stream): | +--------------result-------------------+ | block0 | block1 | block2 | block3 | block4 | block5 | block6 | block7 | | +-------------------------+ + + # for round-robin-split + | +-----------before allgather------------+| + | dp_atten_tp0: token0, token4, token8, token12, token16, ... | + | dp_atten_tp1: token1, token5, token9, token13, token17, ... | + | dp_atten_tp2: token2, token6, token10, token14, token18, ... | + | dp_atten_tp3: token3, token7, token11, token15, token19, ... | + | + | +--------------result-------------------+ + | token0, token1, token2, token3, token4, token5, token6, token7, ... + | +-------------------------+ """ + if is_nsa_prefill_cp_round_robin_split(): + output_tensor = input_tensor.new_empty( + (input_tensor.shape[0] * cp_size, *input_tensor.shape[1:]), + ) + attn_tp_all_gather_into_tensor( + output_tensor, + input_tensor, + ) + out_shape = output_tensor.shape + output_tensor = ( + output_tensor.view(cp_size, -1, *out_shape[1:]) + .transpose(0, 1) + .reshape(out_shape) + ) + return output_tensor + bs_seq_len, hidden_size = input_tensor.shape output_tensor = cp_attn_tp_all_gather_reorganazied_into_tensor( input_tensor, @@ -236,6 +437,8 @@ def prepare_input_dp_with_cp_dsa( cp_size, seqs_len, ): + if is_nsa_prefill_cp_round_robin_split(): + return True """prepare_input_dp_with_cp_dsa-zigzag index Example (DP_ATTENT_TP == CP_SIZE == 4): Description: @@ -274,6 +477,7 @@ def prepare_input_dp_with_cp_dsa( - To mitigate uneven load, the input hissenstate needs to be sliced by cp_size*2 and rearranged. 
""" # just support batch = 1 + kv_len = torch.tensor(kv_len) bs_per_cp_group = 1 kv_len_origin = kv_len # get zigzag index diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index a28ad66dcf50..5f26347c8a7e 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, TypeAlias import torch @@ -25,8 +25,12 @@ NSA_ENABLE_MTP_PRECOMPUTE_METADATA, NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8, NSA_FUSE_TOPK, + can_nsa_prefill_cp_round_robin_split, compute_nsa_seqlens, is_nsa_enable_prefill_cp, + nsa_cp_round_robin_split_data, + nsa_cp_round_robin_split_q_seqs, + pad_nsa_cache_seqlens, ) from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -125,6 +129,13 @@ class NSAMetadata: # shape: (seq_lens_sum,) topk_indices_offset: Optional[torch.Tensor] = None + # k_start and k_end in kv cache for each token. + indexer_k_start_end: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + # seq lens for each batch. + indexer_seq_lens_cpu: Optional[torch.Tensor] = None + # batch index for each token. + token_to_batch_idx: Optional[torch.Tensor] = None + class TopkTransformMethod(IntEnum): # Transform topk indices to indices to the page table (page_size = 1) @@ -172,6 +183,15 @@ def get_seqlens_expanded(self) -> torch.Tensor: def get_cu_seqlens_k(self) -> torch.Tensor: return self.attn_metadata.cu_seqlens_k + def get_indexer_kvcache_range(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.attn_metadata.indexer_k_start_end + + def get_indexer_seq_len_cpu(self) -> torch.Tensor: + return self.attn_metadata.indexer_seq_lens_cpu + + def get_token_to_batch_idx(self) -> torch.Tensor: + return self.attn_metadata.token_to_batch_idx + def topk_transform( self, logits: torch.Tensor, @@ -354,6 +374,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Centralized dispatch: decide all strategies for this batch self.set_nsa_prefill_impl(forward_batch) topk_transform_method = self.get_topk_transform_method() + # Batch indices selected when cp enabled: After splitting multiple sequences, + # a certain cp rank may not have some of these sequences. + # We use bs_idx_cpu to mark which sequences are finally selected by the current cp rank, + # a default value of None indicates that all sequences are selected. 
+ bs_idx_cpu = None + # seq_len_cpu of selected sequences + indexer_seq_lens_cpu = forward_batch.seq_lens_cpu if forward_batch.forward_mode.is_decode_or_idle(): extend_seq_lens_cpu = [1] * batch_size @@ -441,7 +468,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table = torch.repeat_interleave( page_table, repeats=forward_batch.extend_seq_lens, dim=0 ) - elif forward_batch.forward_mode.is_extend(): assert ( forward_batch.extend_seq_lens_cpu is not None @@ -450,18 +476,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ), "All of them must not be None" extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu assert forward_batch.extend_seq_lens is not None - - if ( - any(forward_batch.extend_prefix_lens_cpu) - or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND - ): - max_seqlen_q = max(extend_seq_lens_cpu) - cu_seqlens_q = compute_cu_seqlens( - forward_batch.extend_seq_lens.to(torch.int32) - ) - else: - max_seqlen_q = max_seqlen_k - cu_seqlens_q = cu_seqlens_k + extend_seq_lens = forward_batch.extend_seq_lens seqlens_expanded = torch.cat( [ @@ -479,6 +494,36 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ] ) + if can_nsa_prefill_cp_round_robin_split(forward_batch): + seqlens_expanded = nsa_cp_round_robin_split_data(seqlens_expanded) + extend_seq_lens_cpu, extend_seq_lens, bs_idx_cpu, bs_idx = ( + nsa_cp_round_robin_split_q_seqs( + extend_seq_lens_cpu, extend_seq_lens + ) + ) + indexer_seq_lens_cpu = indexer_seq_lens_cpu[bs_idx_cpu] + cache_seqlens_int32 = cache_seqlens_int32[bs_idx] + cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32) + max_seqlen_k = ( + int(indexer_seq_lens_cpu.max().item() + draft_token_num) + if len(indexer_seq_lens_cpu) != 0 + else 0 + ) + page_table = page_table[bs_idx, :max_seqlen_k] + + if ( + any(forward_batch.extend_prefix_lens_cpu) + or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND + or bs_idx_cpu is not None + ): + max_seqlen_q = ( + max(extend_seq_lens_cpu) if len(extend_seq_lens_cpu) != 0 else 1 + ) + cu_seqlens_q = compute_cu_seqlens(extend_seq_lens.to(torch.int32)) + else: + max_seqlen_q = max_seqlen_k + cu_seqlens_q = cu_seqlens_k + # Check if MHA FP8 dequantization is needed mha_dequantize_needed = ( self.use_mha @@ -496,13 +541,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): [ page_table[i, :kv_len] for i, kv_len in enumerate( - forward_batch.seq_lens_cpu.tolist(), + indexer_seq_lens_cpu.tolist(), ) ] ) - assert ( - page_table_1_flattened.shape[0] == forward_batch.seq_lens_sum - ), f"{page_table_1_flattened.shape[0] = } must be the same as {forward_batch.seq_lens_sum = }" + assert page_table_1_flattened.shape[0] == sum( + indexer_seq_lens_cpu + ), f"{page_table_1_flattened.shape[0] = } must be the same as {sum(indexer_seq_lens_cpu) = }" # Validate indices when logical tokens exceed physical capacity # This is likely to be triggered by PP with high kv reuse & parallelism @@ -520,16 +565,22 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if topk_transform_method == TopkTransformMethod.RAGGED: topk_indices_offset = torch.repeat_interleave( cu_seqlens_k[:-1], - forward_batch.extend_seq_lens, + extend_seq_lens, ) else: assert False, f"Unsupported {forward_batch.forward_mode = }" + indexer_k_start_end, token_to_batch_idx = self._cal_indexer_k_start_end( + forward_batch, bs_idx_cpu + ) # 1D, expanded seqlens (1D means cheap to compute, so always compute it) nsa_cache_seqlens_int32 = compute_nsa_seqlens( original_seq_lens=seqlens_expanded, 
nsa_index_topk=self.nsa_index_topk, ) + nsa_cache_seqlens_int32 = pad_nsa_cache_seqlens( + forward_batch, nsa_cache_seqlens_int32 + ) nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32) nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k)) @@ -586,10 +637,88 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): real_page_table=self._transform_table_1_to_real(page_table), nsa_max_seqlen_q=1, topk_indices_offset=topk_indices_offset, + indexer_k_start_end=indexer_k_start_end, + indexer_seq_lens_cpu=indexer_seq_lens_cpu, + token_to_batch_idx=token_to_batch_idx, ) - self.forward_metadata = metadata + def _cal_indexer_k_start_end( + self, + forward_batch: ForwardBatch, + bs_idx: Optional[List[int]] = None, + ): + if not forward_batch.forward_mode.is_extend_without_speculative(): + return None, None + if forward_batch.batch_size == 0 or (bs_idx is not None and len(bs_idx) == 0): + empty_t = torch.empty(0, dtype=torch.int32, device=self.device) + return (empty_t, empty_t), empty_t + + # Suppose there are two requests, with extend_seq_len = [3, 2] + # and seq_lens = [10, 4] + # The logits matrix looks like this, with * representing the valid logits + # and - representing the invalid logits: + # + # ********--|---- + # *********-|---- + # **********|---- + # ----------|***- + # ----------|**** + # + # ks = [0, 0, 0, 10, 10] + # ke = [8, 9, 10, 13, 14] + ks_list = [] + ke_list = [] + token_to_batch_idx = [] + + q_offset = 0 + k_offset = 0 + + assert ( + forward_batch.seq_lens_cpu is not None + and forward_batch.extend_seq_lens_cpu is not None + ) + for i in range(forward_batch.batch_size): + seq_len = forward_batch.seq_lens_cpu[i].item() + assert isinstance(seq_len, int) + extend_seq_len = forward_batch.extend_seq_lens_cpu[i] + ks = torch.full( + (extend_seq_len,), k_offset, dtype=torch.int32, device=self.device + ) + kv_len = seq_len + if forward_batch.forward_mode.is_target_verify(): + kv_len += self.speculative_num_draft_tokens + seq_lens_expanded = torch.arange( + kv_len - extend_seq_len + 1, + kv_len + 1, + dtype=torch.int32, + device=self.device, + ) + ke = ks + seq_lens_expanded + ks_list.append(ks) + ke_list.append(ke) + + # bi: The index within the selected batch bs_idx. Entries that were not selected are ignored. + bi = bs_idx.index(i) if (bs_idx is not None and i in bs_idx) else i + tb = torch.full( + (extend_seq_len,), bi, dtype=torch.int32, device=self.device + ) + token_to_batch_idx.append(tb) + + if bs_idx is None or i in bs_idx: # skip batch not included in bs_idx + q_offset += extend_seq_len + k_offset += seq_len + + ks = torch.cat(ks_list, dim=0) + ke = torch.cat(ke_list, dim=0) + token_to_batch_idx = torch.cat(token_to_batch_idx, dim=0) + if bs_idx is not None: + assert can_nsa_prefill_cp_round_robin_split(forward_batch) + ks = nsa_cp_round_robin_split_data(ks) + ke = nsa_cp_round_robin_split_data(ke) + token_to_batch_idx = nsa_cp_round_robin_split_data(token_to_batch_idx) + return (ks, ke), token_to_batch_idx + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): """Initialize CUDA graph state for the attention backend. 
diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 4ee1865e0773..15df851ebc40 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -29,6 +29,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.layers.attention.nsa.utils import ( + is_nsa_enable_prefill_cp, + nsa_use_prefill_cp, +) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather_into_tensor, attn_tp_reduce_scatter_tensor, @@ -95,6 +99,8 @@ class ScatterMode(Enum): @staticmethod def model_input_output(): """The scatter mode for model forward pass input and output data""" + if is_nsa_enable_prefill_cp(): + return ScatterMode.SCATTERED return ScatterMode.TP_ATTN_FULL @@ -330,6 +336,12 @@ def __init__( self.qkv_latent_func = qkv_latent_func self._context = CommunicateContext.init_new() + self._post_init_communicate() + self._speculative_algo = SpeculativeAlgorithm.from_string( + get_global_server_args().speculative_algorithm + ) + + def _post_init_communicate(self): self._communicate_simple_fn = CommunicateSimpleFn.get_fn( input_mode=self.layer_scatter_modes.layer_input_mode, output_mode=self.layer_scatter_modes.attn_mode, @@ -353,10 +365,6 @@ def __init__( ) ) - self._speculative_algo = SpeculativeAlgorithm.from_string( - get_global_server_args().speculative_algorithm - ) - def prepare_attn_and_capture_last_layer_outputs( self, hidden_states: torch.Tensor, @@ -545,6 +553,8 @@ def should_use_reduce_scatter(self, forward_batch: ForwardBatch): and forward_batch.dp_padding_mode.is_max_len() ): return True + if nsa_use_prefill_cp(forward_batch): + return True if get_attn_tp_context().input_scattered and not self.is_last_layer: return True return False diff --git a/python/sglang/srt/layers/communicator_nsa_cp.py b/python/sglang/srt/layers/communicator_nsa_cp.py index 2b323edf1a69..d3f668edbc04 100644 --- a/python/sglang/srt/layers/communicator_nsa_cp.py +++ b/python/sglang/srt/layers/communicator_nsa_cp.py @@ -18,7 +18,10 @@ import torch -from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp +from sglang.srt.layers.attention.nsa.utils import ( + is_nsa_enable_prefill_cp, + nsa_use_prefill_cp, +) from sglang.srt.layers.communicator import ( CommunicateContext, CommunicateSimpleFn, @@ -28,6 +31,11 @@ LayerScatterModes, ScatterMode, ) +from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather_into_tensor, + attn_tp_reduce_scatter_tensor, + get_local_dp_buffer, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -57,27 +65,30 @@ def __init__( is_last_layer, qkv_latent_func, ) + + def _post_init_communicate(self): + # SCATTERED in attn tp is different from SCATTERED in global tp when dp_size > 1 + if self.layer_scatter_modes.mlp_mode != ScatterMode.SCATTERED: + assert ( + self._context.attn_dp_size == 1 + ), f"dp_size should be 1 when moe_runner_backend is none" self._communicate_simple_fn = NSACPCommunicateSimpleFn.get_fn( - input_mode=self.layer_scatter_modes.layer_input_mode, - output_mode=self.layer_scatter_modes.attn_mode, + input_mode=ScatterMode.SCATTERED, + output_mode=ScatterMode.SCATTERED, context=self._context, ) - self._communicate_with_all_reduce_and_layer_norm_fn = ( - NSACPCommunicateWithAllReduceAndLayerNormFn.get_fn( - hidden_states_input_mode=self.layer_scatter_modes.attn_mode, - residual_input_mode=self.layer_scatter_modes.layer_input_mode, - 
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode, - residual_output_mode=self.layer_scatter_modes.middle_residual_mode, - context=self._context, - ) + self._communicate_with_all_reduce_and_layer_norm_fn = NSACPCommunicateWithAllReduceAndLayerNormFn.get_fn( + hidden_states_input_mode=ScatterMode.SCATTERED, + residual_input_mode=ScatterMode.SCATTERED, + hidden_states_output_mode=self.layer_scatter_modes.mlp_mode, # SCATTERED, FULL + residual_output_mode=ScatterMode.SCATTERED, + context=self._context, ) - self._communicate_summable_tensor_pair_fn = ( - NSACPCommunicateSummableTensorPairFn.get_fn( - hidden_states_input_mode=self.layer_scatter_modes.mlp_mode, - residual_input_mode=self.layer_scatter_modes.middle_residual_mode, - output_mode=self.layer_scatter_modes.layer_output_mode, - context=self._context, - ) + self._communicate_summable_tensor_pair_fn = NSACPCommunicateSummableTensorPairFn.get_fn( + hidden_states_input_mode=self.layer_scatter_modes.mlp_mode, # SCATTERED, FULL + residual_input_mode=ScatterMode.SCATTERED, + output_mode=ScatterMode.SCATTERED, + context=self._context, ) @@ -91,25 +102,8 @@ def get_fn( if context.is_same_group_size(input_mode, output_mode): return NSACPCommunicateSimpleFn._trivial - if (input_mode == ScatterMode.SCATTERED) and ( - output_mode == ScatterMode.TP_ATTN_FULL - ): - return NSACPCommunicateSimpleFn._scattered_to_tp_attn_full - raise NotImplementedError(f"{input_mode=} {output_mode=}") - @staticmethod - def _scattered_to_tp_attn_full( - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - context: CommunicateContext, - ) -> torch.Tensor: - - if nsa_enable_prefill_cp(): - return hidden_states - else: - assert False, "Not implemented" - class NSACPCommunicateWithAllReduceAndLayerNormFn( CommunicateWithAllReduceAndLayerNormFn @@ -127,41 +121,18 @@ def get_fn( residual_output_mode: ScatterMode, context: CommunicateContext, ): - if ( - context.is_same_group_size( - hidden_states_input_mode, hidden_states_output_mode - ) - and context.is_same_group_size(residual_input_mode, residual_output_mode) - and context.attn_tp_size == 1 - ): + assert hidden_states_input_mode == ScatterMode.SCATTERED + assert residual_input_mode == ScatterMode.SCATTERED + assert residual_output_mode == ScatterMode.SCATTERED + if hidden_states_output_mode == ScatterMode.SCATTERED: return NSACPCommunicateWithAllReduceAndLayerNormFn._simple - if ( - (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) - and ( - residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL] - ) - and (hidden_states_output_mode == ScatterMode.FULL) - and (residual_output_mode == ScatterMode.TP_ATTN_FULL) - ): + if hidden_states_output_mode == ScatterMode.FULL: return partial( NSACPCommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual, residual_input_mode=residual_input_mode, ) - if ( - (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) - and ( - residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL] - ) - and (hidden_states_output_mode == ScatterMode.SCATTERED) - and (residual_output_mode == ScatterMode.SCATTERED) - ): - return partial( - NSACPCommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual, - residual_input_mode=residual_input_mode, - ) - raise NotImplementedError( f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}" ) @@ -176,30 +147,21 @@ def _gather_hidden_states_and_residual( *, residual_input_mode, ): - if nsa_enable_prefill_cp(): - hidden_states += 
residual - if hidden_states.shape[0] != 0: - hidden_states = layernorm(hidden_states) - return hidden_states, residual - else: - assert False, "not yet handled" - - @staticmethod - def _scatter_hidden_states_and_residual( - hidden_states: torch.Tensor, - residual: torch.Tensor, - forward_batch: ForwardBatch, - layernorm: torch.nn.Module, - context: CommunicateContext, - *, - residual_input_mode, - ): - if nsa_enable_prefill_cp(): - if hidden_states.shape[0] != 0: - hidden_states, residual = layernorm(hidden_states, residual) - return hidden_states, residual - else: - assert False, "not yet handled" + if hidden_states.shape[0] != 0: + hidden_states, residual = layernorm(hidden_states, residual) + # for prefill: attn tp scattered -> full + # for decode: attn tp full -> full + if nsa_use_prefill_cp(forward_batch): + assert context.attn_dp_size == 1 + hidden_states, local_hidden_states = ( + get_local_dp_buffer(), + hidden_states, + ) + attn_tp_all_gather_into_tensor( + hidden_states, + local_hidden_states, + ) + return hidden_states, residual class NSACPCommunicateSummableTensorPairFn(CommunicateSummableTensorPairFn): @@ -219,24 +181,10 @@ def get_fn( if ( (hidden_states_input_mode == ScatterMode.FULL) - and (residual_input_mode == ScatterMode.TP_ATTN_FULL) - and (output_mode == ScatterMode.TP_ATTN_FULL) - ): - return NSACPCommunicateSummableTensorPairFn._scatter_hidden_states - - if ( - (hidden_states_input_mode == ScatterMode.SCATTERED) and (residual_input_mode == ScatterMode.SCATTERED) - and (output_mode == ScatterMode.TP_ATTN_FULL) - ): - return NSACPCommunicateSummableTensorPairFn._gather - - if ( - (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL) - and (residual_input_mode == ScatterMode.TP_ATTN_FULL) and (output_mode == ScatterMode.SCATTERED) ): - return NSACPCommunicateSummableTensorPairFn._scatter + return NSACPCommunicateSummableTensorPairFn._scatter_hidden_states raise NotImplementedError( f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}" @@ -250,34 +198,13 @@ def _scatter_hidden_states( context: CommunicateContext, allow_reduce_scatter: bool = False, ): - if nsa_enable_prefill_cp(): - return hidden_states, residual - else: - assert False, "not yet handled" - - @staticmethod - def _gather( - hidden_states: torch.Tensor, - residual: torch.Tensor, - forward_batch: ForwardBatch, - context: CommunicateContext, - **kwargs, - ): - hidden_states += residual - residual = None - if nsa_enable_prefill_cp(): - return hidden_states, residual - else: - assert False, "not yet handled" - - @staticmethod - def _scatter( - hidden_states: torch.Tensor, - residual: torch.Tensor, - forward_batch: ForwardBatch, - context: CommunicateContext, - ): - if nsa_enable_prefill_cp(): - return hidden_states, residual - else: - assert False, "not yet handled" + # for prefill: full -> attn tp scattered + # for decode: full -> attn tp full + if nsa_use_prefill_cp(forward_batch): + assert context.attn_dp_size == 1 + input_hidden_states = hidden_states + hidden_states = hidden_states.tensor_split(context.attn_tp_size)[ + context.attn_tp_rank + ] + attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states) + return hidden_states, residual diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index f88c002be557..1dc39ae05dcd 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -24,7 +24,7 @@ import torch -from sglang.srt.layers.attention.nsa.utils import 
is_nsa_enable_prefill_cp +from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_in_seq_split from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -363,7 +363,7 @@ def __init__( self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold ) - self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + self.nsa_prefill_cp_in_seq_split = is_nsa_prefill_cp_in_seq_split() self.prefill_max_requests = prefill_max_requests def _get_running_request_total_token_offset(self, req: Req) -> int: @@ -573,7 +573,7 @@ def add_one_req( # TODO support cp with multiple requests # Enabling context parallelism currently presents precision issues; # therefore, the prefill-batch setting is temporarily set to 1. - if self.nsa_enable_prefill_cp and len(self.can_run_list) >= 1: + if self.nsa_prefill_cp_in_seq_split and len(self.can_run_list) >= 1: return AddReqResult.OTHER if (x := self.prefill_max_requests) is not None and len(self.can_run_list) >= x: diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 5de69a9e3d6f..24a4cd1db253 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -28,8 +28,8 @@ can_cp_split, cp_all_gather_rerange_output, cp_split_and_rebuild_data, - enable_prefill_cp, is_nsa_enable_prefill_cp, + nsa_use_prefill_cp, prepare_input_dp_with_cp_dsa, ) from sglang.srt.layers.dp_attention import ( @@ -160,7 +160,7 @@ def forward( ) ) - if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) residual = None with get_global_expert_distribution_recorder().disable_this_region(): @@ -178,7 +178,7 @@ def forward( else: hidden_states = self.shared_head.norm(hidden_states) - if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): # allgather + rerrange hidden_states = cp_all_gather_rerange_output( hidden_states, @@ -235,10 +235,9 @@ def forward( ) -> torch.Tensor: # TODO current just support prefill batch=1 and len(input_ids) > self.cp_size * 2 if self.nsa_enable_prefill_cp: - cur_cp_seq_len = len(input_ids) // (self.cp_size * 2) - if can_cp_split(cur_cp_seq_len, self.cp_size, self.use_nsa, forward_batch): + if can_cp_split(len(input_ids), self.cp_size, self.use_nsa, forward_batch): forward_batch.nsa_cp_metadata = prepare_input_dp_with_cp_dsa( - torch.tensor(len(input_ids)), + len(input_ids), self.cp_rank, self.cp_size, forward_batch.seq_lens_cpu.tolist(), diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 33e820785316..2918461d3bb7 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -64,8 +64,8 @@ cp_all_gather_rerange_output, cp_split_and_rebuild_data, cp_split_and_rebuild_position, - enable_prefill_cp, is_nsa_enable_prefill_cp, + nsa_use_prefill_cp, prepare_input_dp_with_cp_dsa, ) from sglang.srt.layers.attention.tbo_backend import TboAttnBackend @@ -578,9 +578,7 @@ def forward( if get_global_server_args().enable_deterministic_inference: return F.linear(hidden_states, self.weight, None) - if forward_batch is not None and enable_prefill_cp( - forward_batch, self.nsa_enable_prefill_cp - ): + 
if forward_batch is not None and nsa_use_prefill_cp(forward_batch): logits = F.linear(hidden_states, self.weight, None) else: # NOTE: For some unknown reason, router_gemm seems degrade accept length. @@ -2006,8 +2004,6 @@ def forward_absorb_prepare( q_nope_out = q_nope_out.transpose(0, 1) - if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): - positions = cp_split_and_rebuild_position(forward_batch, positions) if ( self.rotary_emb is not None and (not self._fuse_rope_for_trtllm_mla(forward_batch)) @@ -2015,7 +2011,7 @@ def forward_absorb_prepare( ): q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): + if nsa_use_prefill_cp(forward_batch): # support allgather+rerrange k_nope, k_pe = self.rebuild_cp_kv_cache( latent_cache, forward_batch, k_nope, k_pe @@ -3181,8 +3177,10 @@ def forward( hidden_states = pp_proxy_tensors["hidden_states"] residual = pp_proxy_tensors["residual"] - if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): - hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + if nsa_use_prefill_cp(forward_batch): + if self.pp_group.is_first_rank: + hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + positions = cp_split_and_rebuild_position(forward_batch, positions) # llama_4_scaling: for supporting Mistral-Large-3 model # Compute llama 4 scaling once per forward pass if enabled @@ -3262,7 +3260,7 @@ def forward( else: hidden_states, _ = self.norm(hidden_states, residual) - if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): + if self.pp_group.is_last_rank and nsa_use_prefill_cp(forward_batch): # allgather + rerrange hidden_states = cp_all_gather_rerange_output( hidden_states, @@ -3400,13 +3398,9 @@ def forward( pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> torch.Tensor: if self.nsa_enable_prefill_cp: - # TODO current just support prefill batch=1 and len(input_ids) > self.cp_size * 2 - # Note: (self.cp_size * 2) To achieve load balancing for seq computation, - # the seq data needs to be divided and recombined at twice the size of cp_size. - cur_cp_seq_len = len(input_ids) // (self.cp_size * 2) - if can_cp_split(cur_cp_seq_len, self.cp_size, self.use_nsa, forward_batch): + if can_cp_split(len(input_ids), self.cp_size, self.use_nsa, forward_batch): forward_batch.nsa_cp_metadata = prepare_input_dp_with_cp_dsa( - torch.tensor(len(input_ids)), + len(input_ids), self.cp_rank, self.cp_size, forward_batch.seq_lens_cpu.tolist(), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 93609beb25c1..b2a09577b3db 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -149,6 +149,8 @@ RADIX_SUPPORTED_DETERMINISTIC_ATTENTION_BACKEND = ["fa3", "triton"] +NSA_PREFILL_CP_SPLIT_CHOICES = ["in-seq-split", "round-robin-split"] + DEFAULT_LORA_EVICTION_POLICY = "lru" NSA_CHOICES = [ @@ -575,6 +577,7 @@ class ServerArgs: enable_attn_tp_input_scattered: bool = False # Context parallelism used in the long sequence prefill phase of DeepSeek v3.2 enable_nsa_prefill_context_parallel: bool = False + nsa_prefill_cp_mode: str = "in-seq-split" enable_fused_qk_norm_rope: bool = False enable_precise_embedding_interpolation: bool = False @@ -1067,9 +1070,10 @@ def _handle_model_specific_adjustments(self): if self.enable_nsa_prefill_context_parallel: # TODO Supports moe_dense_tp_size != 1, kv cache dtype = "fp8",moe_a2a_backend non-deepep and cross-machine operation . 
self.moe_dense_tp_size = 1 - self.moe_a2a_backend = "deepep" - self.ep_size = self.tp_size - self.kv_cache_dtype = "bf16" + if self.nsa_prefill_cp_mode != "round-robin-split": + self.moe_a2a_backend = "deepep" + self.ep_size = self.tp_size + self.kv_cache_dtype = "bf16" assert ( self.tp_size == 8 ), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)" @@ -4218,6 +4222,14 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable context parallelism used in the long sequence prefill phase of DeepSeek v3.2.", ) + parser.add_argument( + "--nsa-prefill-cp-mode", + type=str, + default=ServerArgs.nsa_prefill_cp_mode, + choices=NSA_PREFILL_CP_SPLIT_CHOICES, + help="Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: 'in-seq-split' (default), 'round-robin-split'. " + "'round-robin-split' distributes tokens across ranks based on token_idx % cp_size. It supports multi-batch prefill, fused MoE, and FP8 KV cache.", + ) parser.add_argument( "--enable-fused-qk-norm-rope", action="store_true", diff --git a/test/manual/test_deepseek_v32_cp_single_node.py b/test/manual/test_deepseek_v32_cp_single_node.py index c2a6b9d4c824..4fc526ce2944 100644 --- a/test/manual/test_deepseek_v32_cp_single_node.py +++ b/test/manual/test_deepseek_v32_cp_single_node.py @@ -95,5 +95,79 @@ def test_a_gsm8k( self.assertGreater(avg_spec_accept_length, 2.7) +class TestDeepseekV32CPMode1(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = FULL_DEEPSEEK_V32_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "1", + "--enable-dp-attention", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--mem-frac", + "0.7", + "--cuda-graph-max-bs", + "32", + "--max-running-requests", + "32", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_a_gsm8k( + self, + ): # Append an "a" to make this test run first (alphabetically) to warm up the server + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=8, + data_path=None, + num_questions=200, + parallel=32, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + + if is_in_ci(): + write_github_step_summary( + f"### test_gsm8k (deepseek-v32 nsa-cp)\n" + f'{metrics["accuracy"]=:.3f}\n' + f"{avg_spec_accept_length=:.2f}\n" + ) + self.assertGreater(metrics["accuracy"], 0.935) + self.assertGreater(avg_spec_accept_length, 2.7) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_prefill_adder.py b/test/srt/test_prefill_adder.py index 60a8b904da70..01772e526d15 100644 --- a/test/srt/test_prefill_adder.py +++ b/test/srt/test_prefill_adder.py @@ -12,7 +12,7 @@ def setUp(self): self.mock_tree_cache = self.create_tree_cache() 
self.mock_token_allocator = self.create_token_allocator() patcher = patch( - "sglang.srt.managers.schedule_policy.is_nsa_enable_prefill_cp", + "sglang.srt.managers.schedule_policy.is_nsa_prefill_cp_in_seq_split", return_value=False, ) self.mock_is_nsa = patcher.start()