Skip to content
1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-layerwise-nvtx-marker` | Enable layerwise NVTX profiling annotations for the model. This adds NVTX markers to every layer for detailed per-layer performance analysis with Nsight Systems. | `False` | bool flag (set to enable) |
| `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) |
| `--enable-nsa-prefill-context-parallel` | Context parallelism used in the long sequence prefill phase of DeepSeek v3.2 | `False` | bool flag (set to enable) |
| `--nsa-prefill-cp-mode` | Token splitting mode for the prefill phase of DeepSeek v3.2 under context parallelism. Optional values: `in-seq-split` (default), `round-robin-split`. `round-robin-split` distributes tokens across ranks based on `token_idx % cp_size`. It supports multi-batch prefill, fused MoE, and FP8 KV cache. | `in-seq-split` | Type: str |

## Forward hooks
| Argument | Description | Defaults | Options |
Expand Down
12 changes: 12 additions & 0 deletions docs/basic_usage/deepseek_v32.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,15 @@ Some features are still not supported at present.
- **Other Args**: Currently only supports `moe_dense_tp_size=1`, `kv_cache_dtype="bf16"`, and `moe_a2a_backend="deepep"`.
- **DP_size**: `CP_size` reuses `atten_tp_size`, which is equal to `TP_size / DP_size`. For CP to work correctly, `TP_size` must be divisible by `DP_size`, and `TP_size / DP_size` must be greater than 1 (to ensure `CP_size` > 1).
- **Detailed design reference**: https://github.com/sgl-project/sglang/pull/12065

Comment thread
xu-yfei marked this conversation as resolved.
### Alternative context parallel mode

You can switch the CP token splitting mode for prefill by specifying `--nsa-prefill-cp-mode round-robin-split`, which distributes tokens across ranks based on `token_idx % cp_size`.
Compared with the default mode described above, it additionally supports the fused MoE backend (which may deliver better performance than DeepEP in single-machine scenarios),
FP8 KV cache, and multi-batch prefill inference. For more details, please refer to PR https://github.com/sgl-project/sglang/pull/13959.

Example usage:
```bash
# Launch with FusedMoe + CP8 + DP1
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 1 --enable-dp-attention --enable-nsa-prefill-context-parallel --nsa-prefill-cp-mode round-robin-split --max-running-requests 32
```
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from sglang.srt.layers.attention.nsa.utils import (
cp_split_and_rebuild_position,
enable_prefill_cp,
nsa_use_prefill_cp,
)
from sglang.srt.layers.communicator import get_attn_tp_context

Expand Down Expand Up @@ -192,12 +192,12 @@ def forward_mla_prepare_npu(

q_nope_out = q_nope_out.transpose(0, 1)

if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
positions = cp_split_and_rebuild_position(forward_batch, positions)

q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)

if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
# support allgather+rerrange
k_nope, k_pe = m.rebuild_cp_kv_cache(
latent_cache, forward_batch, k_nope, k_pe
Expand Down Expand Up @@ -338,12 +338,12 @@ def forward_dsa_prepare_npu(

q_nope_out = q_nope_out.transpose(0, 1)

if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
positions = cp_split_and_rebuild_position(forward_batch, positions)

q_pe, k_pe = m.rotary_emb(positions, q_pe, k_pe)

if enable_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
if nsa_use_prefill_cp(forward_batch, m.nsa_enable_prefill_cp):
# support allgather+rerrange
k_nope, k_pe = m.rebuild_cp_kv_cache(
latent_cache, forward_batch, k_nope, k_pe
Expand Down
113 changes: 45 additions & 68 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
NSA_DUAL_STREAM,
cp_all_gather_rerange_output,
is_nsa_enable_prefill_cp,
is_nsa_prefill_cp_in_seq_split,
)
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.linear import ReplicatedLinear
Expand Down Expand Up @@ -63,6 +64,21 @@ def get_seqlens_expanded(self) -> torch.Tensor:
Return: (sum_extend_seq_len,) int32 tensor
"""

def get_indexer_kvcache_range(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Return: (tokens, ), (tokens, ) int32, k_start and k_end in kv cache(token,xxx) for each token.
"""

def get_indexer_seq_len_cpu(self) -> torch.Tensor:
"""
Return: seq lens for each batch.
"""

def get_token_to_batch_idx(self) -> torch.Tensor:
"""
Return: batch idx for each token.
"""

@abstractmethod
def topk_transform(
self,
Expand Down Expand Up @@ -227,15 +243,6 @@ def _get_q_k_bf16(
query[..., : self.rope_head_dim] = q_rope
key[..., : self.rope_head_dim] = k_rope

# allgather+rerrange
if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp:
key = cp_all_gather_rerange_output(
key.contiguous(),
self.cp_size,
forward_batch,
torch.cuda.current_stream(),
)

if enable_dual_stream:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
Expand All @@ -248,6 +255,14 @@ def _get_q_k_bf16(
query = rotate_activation(query)
key = rotate_activation(key)

# allgather+rerrange
Comment thread
Fridge003 marked this conversation as resolved.
if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp:
key = cp_all_gather_rerange_output(
key.contiguous(),
self.cp_size,
forward_batch,
torch.cuda.current_stream(),
)
return query, key

def _get_k_bf16(
Expand Down Expand Up @@ -373,72 +388,51 @@ def _get_topk_ragged(
weights = weights.squeeze(-1)
k_fp8_list = []
k_scale_list = []
ks_list = []
ke_list = []
# Token-to-batch mapping for PAGED chunk alignment
token_to_batch_idx: List[int] = []

q_offset = 0
k_offset = 0

seq_lens_expanded = metadata.get_seqlens_expanded()
block_tables = metadata.get_page_table_64()

assert (
forward_batch.seq_lens_cpu is not None
and forward_batch.extend_seq_lens_cpu is not None
)

for i in range(forward_batch.batch_size):
seq_len = forward_batch.seq_lens_cpu[i].item()
batch_size = len(block_tables)
token_nums, _, _ = q_fp8.shape
device = q_fp8.device
topk_result = torch.full(
(token_nums, self.index_topk), -1, device=device, dtype=torch.int32
)
if batch_size == 0:
return topk_result

indexer_seq_lens_cpu = metadata.get_indexer_seq_len_cpu()
assert len(indexer_seq_lens_cpu) == batch_size
for i in range(batch_size):
seq_len = indexer_seq_lens_cpu[i].item()
assert isinstance(seq_len, int)
# Use fused Triton kernel to get both K and scale in a single call
k_fp8, k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_buffer(
layer_id,
seq_len,
block_tables[i],
)
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
ks = torch.full(
(extend_seq_len,), k_offset, dtype=torch.int32, device="cuda"
)
ke = ks + seq_lens_expanded[q_offset : q_offset + extend_seq_len]
k_fp8_list.append(k_fp8)
k_scale_list.append(k_scale)
ks_list.append(ks)
ke_list.append(ke)

token_to_batch_idx.extend([i] * extend_seq_len)
q_offset += extend_seq_len
k_offset += seq_len

k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
kv_fp8 = (k_fp8, k_scale)
ks = torch.cat(ks_list, dim=0)
ke = torch.cat(ke_list, dim=0)

# Suppose there are two requests, with extend_seq_len = [3, 2]
# and seq_lens = [10, 4]
# The logits matrix looks like this, with * representing the valid logits
# and - representing the invalid logits:
#
# ********--|----
# *********-|----
# **********|----
# ----------|***-
# ----------|****
#
# ks = [0, 0, 0, 10, 10]
# ke = [8, 9, 10, 13, 14]

token_nums, _, _ = q_fp8.shape
device = q_fp8.device
ks, ke = metadata.get_indexer_kvcache_range()
seq_lens_expanded = metadata.get_seqlens_expanded()
token_to_batch_idx = metadata.get_token_to_batch_idx()
q_offset = ks.shape[0]
k_offset = k_fp8.shape[0]

# Check if we need to chunk to avoid OOM
need_chunk, free_mem = self._should_chunk_mqa_logits(q_offset, k_offset, device)

if not need_chunk:
assert q_fp8[:q_offset].shape[0] != 0
logits = deep_gemm.fp8_mqa_logits(
q_fp8[:q_offset],
kv_fp8,
Expand All @@ -451,12 +445,6 @@ def _get_topk_ragged(
assert logits.shape[1] == k_offset

raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks)
topk_result = torch.full(
(token_nums, self.index_topk),
-1,
device=device,
dtype=torch.int32,
)
topk_result[:q_offset] = raw_topk_result
return topk_result

Expand All @@ -477,17 +465,6 @@ def _get_topk_ragged(
global_topk_offset.shape[0] >= q_offset
), f"topk_indices_offset too short: {global_topk_offset.shape[0]} < {q_offset}"

topk_result = torch.full(
(token_nums, self.index_topk), -1, device=device, dtype=torch.int32
)

# Only materialize batch index tensor when PAGED path needs it
token_to_batch_idx_tensor = None
if global_topk_offset is None:
token_to_batch_idx_tensor = torch.tensor(
token_to_batch_idx, dtype=torch.long, device=device
)

start = 0
while start < q_offset:
end = min(start + max_rows, q_offset)
Expand Down Expand Up @@ -516,7 +493,7 @@ def _get_topk_ragged(
cu_seqlens_q_chunk = torch.ones(
B_chunk, dtype=torch.int32, device=device
)
batch_idx_chunk = token_to_batch_idx_tensor[start:end]
batch_idx_chunk = token_to_batch_idx[start:end]
Comment thread
Fridge003 marked this conversation as resolved.

raw_topk_chunk = metadata.topk_transform(
logits_chunk,
Expand Down Expand Up @@ -911,7 +888,7 @@ def forward_cuda(
else:
if (
forward_batch.nsa_cp_metadata is not None
and self.nsa_enable_prefill_cp
and is_nsa_prefill_cp_in_seq_split()
):
kv_len_prev = forward_batch.nsa_cp_metadata.kv_len_prev
kv_len_next = forward_batch.nsa_cp_metadata.kv_len_next
Expand Down
Loading
Loading