Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/basic_usage/deepseek_v32.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ pip3 install -e "python"
To serve DeepSeek-V3.2-Exp on 8xH200/B200 GPUs:

```bash
# Launch with TP + DP
# Launch with TP + DP (Recommended)
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention

# Launch with EP + DP
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 8 --enable-dp-attention

# Launch with Pure TP
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8
```

### Configuration Tips
- **DP Attention**: For DeepSeek V3.2 model, the kernels are customized for the use case of `dp_size=8`, so DP attention is enabled by default for better stability and performance. The feature of launching with pure TP is still under development.
- **DP Attention (Recommended)**: For DeepSeek V3.2 model, the kernels are customized for the use case of `dp_size=8`, so DP attention (`--dp 8 --enable-dp-attention`) is the recommended configuration for better stability and performance. All test cases use this configuration by default.
- **Pure TP Mode**: Launching with pure TP (without `--dp` and `--enable-dp-attention`) is also supported. Note that this mode has not been fully validated in PD disaggregation scenarios.
- **Short-sequence MHA prefill (adaptive)**: For short prefill sequences (default threshold: **2048 tokens**), the NSA backend uses standard MHA automatically (no extra flags). On H200 (SM90) this path uses the FlashAttention variable-length kernel; on B200 (SM100) it uses TRT-LLM ragged MHA. MHA uses `MHA_ONE_SHOT` for best performance. `MHA_ONE_SHOT` computes multi-head attention over all tokens (both cached prefix and newly extended tokens) in a single kernel invocation, avoiding the overhead of chunked KV cache processing. This achieves optimal throughput for short sequences where total sequence length fits within the chunk capacity limit.
- **Choices of Attention Kernels**: The attention backend is automatically set to `nsa` attention backend for DeepSeek V3.2 model. In this backend, different kernels for sparse prefilling/decoding are implemented, which can be specified by `--nsa-prefill-backend` and `--nsa-decode-backend` server arguments. The choices of nsa prefill/decode attention kernels include:
- `flashmla_sparse`: `flash_mla_sparse_fwd` kernel from `flash_mla` library. Can run on both Hopper and Blackwell GPUs. It requires bf16 q, kv inputs.
Expand Down
109 changes: 95 additions & 14 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,25 @@ def _get_topk_paged(
topk_result = metadata.topk_transform(logits, self.index_topk)
return topk_result

def _should_chunk_mqa_logits(
self, num_q: int, num_k: int, device: torch.device
) -> Tuple[bool, int]:
"""
Detect whether we need to chunk the MQA logits computation to avoid OOM
Return: (need_chunk, free_mem)
"""
# Quick static check for normal batches
if num_q * num_k < 8_000_000: # 8M elements ≈ 32MB logits
return False, 0

free_mem, total_mem = torch.cuda.mem_get_info(device)
bytes_per_elem = 4 # float32
logits_bytes = num_q * num_k * bytes_per_elem

# Logits should not exceed 50% of free memory or 30% of total memory
need_chunk = (logits_bytes * 2 > free_mem) or (logits_bytes > total_mem * 0.3)
return need_chunk, free_mem

def _get_topk_ragged(
self,
forward_batch: ForwardBatch,
Expand Down Expand Up @@ -409,24 +428,86 @@ def _get_topk_ragged(
# ks = [0, 0, 0, 10, 10]
# ke = [8, 9, 10, 13, 14]

logits = deep_gemm.fp8_mqa_logits(
q_fp8[:q_offset],
kv_fp8,
weights[:q_offset],
ks,
ke,
clean_logits=False,
)

token_nums, _, _ = q_fp8.shape
assert logits.shape[0] == len(seq_lens_expanded)
assert logits.shape[1] == k_offset
device = q_fp8.device

# Check if we need to chunk to avoid OOM
need_chunk, free_mem = self._should_chunk_mqa_logits(q_offset, k_offset, device)

if not need_chunk:
logits = deep_gemm.fp8_mqa_logits(
q_fp8[:q_offset],
kv_fp8,
weights[:q_offset],
ks,
ke,
clean_logits=False,
)
assert logits.shape[0] == len(seq_lens_expanded)
assert logits.shape[1] == k_offset

raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks)
topk_result = torch.full(
(token_nums, self.index_topk),
-1,
device=device,
dtype=torch.int32,
)
topk_result[:q_offset] = raw_topk_result
return topk_result

# Chunk path
bytes_per_elem = 4 # float32
bytes_per_row = k_offset * bytes_per_elem
# Reserve 50% of free memory for logits
max_rows = max(1, int((free_mem * 0.5) // max(bytes_per_row, 1)))
max_rows = min(max_rows, q_offset)

global_topk_offset = metadata.attn_metadata.topk_indices_offset

assert (
seq_lens_expanded.shape[0] == q_offset
), f"seq_lens_expanded length mismatch: {seq_lens_expanded.shape[0]} != {q_offset}"
if global_topk_offset is not None:
assert (
global_topk_offset.shape[0] >= q_offset
), f"topk_indices_offset too short: {global_topk_offset.shape[0]} < {q_offset}"

raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks)
topk_result = torch.full(
(token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
(token_nums, self.index_topk), -1, device=device, dtype=torch.int32
)
topk_result[:q_offset] = raw_topk_result

start = 0
while start < q_offset:
end = min(start + max_rows, q_offset)

logits_chunk = deep_gemm.fp8_mqa_logits(
q_fp8[start:end],
kv_fp8,
weights[start:end],
ks[start:end],
ke[start:end],
clean_logits=False,
)

lengths_chunk = seq_lens_expanded[start:end]

topk_offset_chunk = (
global_topk_offset[start:end]
if global_topk_offset is not None
else None
)

raw_topk_chunk = metadata.topk_transform(
logits_chunk,
self.index_topk,
ks=ks[start:end],
ke_offset=lengths_chunk,
topk_indices_offset_override=topk_offset_chunk,
)
topk_result[start:end] = raw_topk_chunk
start = end

return topk_result

def _forward_cuda_k_only(
Expand Down
79 changes: 72 additions & 7 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,18 @@ def topk_transform(
cu_seqlens_q: torch.Tensor = None,
ke_offset: torch.Tensor = None,
batch_idx_list: List[int] = None,
topk_indices_offset_override: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from sgl_kernel import (
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)

if cu_seqlens_q is not None:
if topk_indices_offset_override is not None:
cu_topk_indices_offset = topk_indices_offset_override
cu_seqlens_q_topk = None
elif cu_seqlens_q is not None:
cu_seqlens_q = cu_seqlens_q.to(torch.int32)
cu_seqlens_q_topk = compute_cu_seqlens(cu_seqlens_q)
cu_topk_indices_offset = torch.repeat_interleave(
Expand Down Expand Up @@ -286,9 +290,11 @@ def __init__(
)
self.speculative_step_id = speculative_step_id

self.device_capability = torch.cuda.get_device_capability()
self.device_sm_major = self.device_capability[0]

# Allocate global workspace buffer for TRTLLm ragged attention kernel (SM100/B200)
device_sm_major = torch.cuda.get_device_capability()[0]
if device_sm_major >= 10:
if self.device_sm_major >= 10:
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
Expand Down Expand Up @@ -921,6 +927,11 @@ def forward_extend(
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]

# Align topk_indices with q dimensions
# This handles cases where q is padded (TP + partial DP attention)
if topk_indices is not None:
topk_indices = self._pad_topk_indices(topk_indices, q_nope.shape[0])

# NOTE(dark): here, we use page size = 1
topk_transform_method = self.get_topk_transform_method()
if NSA_FUSE_TOPK:
Expand Down Expand Up @@ -1058,6 +1069,10 @@ def forward_decode(
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]

# Align topk_indices with q dimensions
if topk_indices is not None:
topk_indices = self._pad_topk_indices(topk_indices, q_nope.shape[0])

if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
Expand Down Expand Up @@ -1178,13 +1193,43 @@ def _forward_flashmla_sparse(
) -> torch.Tensor:
from sgl_kernel.flash_mla import flash_mla_sparse_fwd

# FlashMLA sparse kernel requires num_heads to be a multiple of 64 (Hopper) or 128 (Blackwell)
# When using TP, num_heads might be smaller (e.g., 256//8=32)
num_tokens, num_heads, head_dim = q_all.shape

# Determine required padding based on GPU architecture (use cached value)
required_padding = 128 if self.device_sm_major >= 10 else 64

need_padding = num_heads % required_padding != 0

if need_padding:
assert required_padding % num_heads == 0, (
f"num_heads {num_heads} cannot be padded to {required_padding}. "
f"TP size may be too large for this model."
)

# Pad q to required size
q_padded = q_all.new_zeros((num_tokens, required_padding, head_dim))
q_padded[:, :num_heads, :] = q_all
q_input = q_padded
else:
q_input = q_all

# indices shape must be (s_q, h_kv=1, topk), keep h_kv=1 unchanged
indices_input = page_table_1.unsqueeze(1)

o, _, _ = flash_mla_sparse_fwd(
q=q_all,
q=q_input,
kv=kv_cache,
indices=page_table_1.unsqueeze(1),
indices=indices_input,
sm_scale=sm_scale,
d_v=v_head_dim,
)

# Trim output back to original num_heads if we padded
if need_padding:
o = o[:, :num_heads, :]

return o

def _forward_flashmla_kv(
Expand Down Expand Up @@ -1259,8 +1304,7 @@ def _forward_standard_mha(
)

# Use TRTLLm ragged attention for SM100 (Blackwell/B200) to avoid FA4 accuracy issues
device_sm_major = torch.cuda.get_device_capability()[0]
if device_sm_major >= 10:
if self.device_sm_major >= 10:
import flashinfer

seq_lens = metadata.cache_seqlens_int32
Expand Down Expand Up @@ -1357,6 +1401,27 @@ def _forward_aiter(
# kv_cache = kv_cache.view(-1, 1, layer.head_dim)
return o

def _pad_topk_indices(
self, topk_indices: torch.Tensor, num_tokens: int
) -> torch.Tensor:
current_tokens = topk_indices.shape[0]
if current_tokens == num_tokens:
return topk_indices

assert current_tokens <= num_tokens, (
f"topk_indices rows ({current_tokens}) > num_tokens ({num_tokens}); "
"this indicates a mismatch between indexer output and q layout."
)

pad_size = num_tokens - current_tokens
padding = torch.full(
(pad_size, topk_indices.shape[1]),
-1,
dtype=topk_indices.dtype,
device=topk_indices.device,
)
return torch.cat([topk_indices, padding], dim=0)

def get_cuda_graph_seq_len_fill_value(self):
    """Fill value used for padded sequence lengths during CUDA graph capture."""
    return 1
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,12 @@ def _handle_model_specific_adjustments(self):
f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} "
)
else:
self.dp_size = self.tp_size
# Pure TP and partial DP Attention mode is active for NSA, logging a warning
if self.dp_size < self.tp_size:
logger.warning(
f"NSA with TP mode is active, dp_size={self.dp_size}, tp_size={self.tp_size}, "
f"attn_tp_size={self.tp_size}, attention weights will be sharded across {self.tp_size} ranks."
)

self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")
Expand Down
25 changes: 25 additions & 0 deletions test/manual/nightly/test_deepseek_v32_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def setUpClass(cls):
"--trust-remote-code",
"--tp",
"8",
"--dp",
"8",
"--enable-dp-attention",
"--model-loader-extra-config",
'{"enable_multithread_load": true}',
],
Expand All @@ -35,6 +38,9 @@ def setUpClass(cls):
"--trust-remote-code",
"--tp",
"8",
"--dp",
"8",
"--enable-dp-attention",
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
Expand All @@ -51,6 +57,25 @@ def setUpClass(cls):
},
{
"name": "nsa",
"other_args": [
"--trust-remote-code",
"--tp",
"8",
"--dp",
"8",
"--enable-dp-attention",
"--attention-backend",
"nsa",
"--nsa-prefill-backend",
"flashmla_sparse",
"--nsa-decode-backend",
"flashmla_kv",
"--model-loader-extra-config",
'{"enable_multithread_load": true}',
],
},
{
"name": "pure_tp",
"other_args": [
"--trust-remote-code",
"--tp",
Expand Down
Loading
Loading