diff --git a/docs/basic_usage/deepseek_v32.md b/docs/basic_usage/deepseek_v32.md index caad4c8758ab..2ef73fad72e2 100644 --- a/docs/basic_usage/deepseek_v32.md +++ b/docs/basic_usage/deepseek_v32.md @@ -34,15 +34,19 @@ pip3 install -e "python" To serve DeepSeek-V3.2-Exp on 8xH200/B200 GPUs: ```bash -# Launch with TP + DP +# Launch with TP + DP (Recommended) python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention # Launch with EP + DP python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 8 --enable-dp-attention + +# Launch with Pure TP +python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 ``` ### Configuration Tips -- **DP Attention**: For DeepSeek V3.2 model, the kernels are customized for the use case of `dp_size=8`, so DP attention is enabled by default for better stability and performance. The feature of launching with pure TP is still under development. +- **DP Attention (Recommended)**: For the DeepSeek V3.2 model, the kernels are customized for the use case of `dp_size=8`, so DP attention (`--dp 8 --enable-dp-attention`) is the recommended configuration for better stability and performance. Most test cases use this configuration by default; pure TP is covered by dedicated test variants. +- **Pure TP Mode**: Launching with pure TP (without `--dp` and `--enable-dp-attention`) is also supported. Note that this mode has not been fully validated in prefill-decode (PD) disaggregation scenarios. - **Short-sequence MHA prefill (adaptive)**: For short prefill sequences (default threshold: **2048 tokens**), the NSA backend uses standard MHA automatically (no extra flags). On H200 (SM90) this path uses the FlashAttention variable-length kernel; on B200 (SM100) it uses TRT-LLM ragged MHA. MHA uses `MHA_ONE_SHOT` for best performance. `MHA_ONE_SHOT` computes multi-head attention over all tokens (both cached prefix and newly extended tokens) in a single kernel invocation, avoiding the overhead of chunked KV cache processing. 
This achieves optimal throughput for short sequences where total sequence length fits within the chunk capacity limit. - **Choices of Attention Kernels**: The attention backend is automatically set to `nsa` attention backend for DeepSeek V3.2 model. In this backend, different kernels for sparse prefilling/decoding are implemented, which can be specified by `--nsa-prefill-backend` and `--nsa-decode-backend` server arguments. The choices of nsa prefill/decode attention kernels include: - `flashmla_sparse`: `flash_mla_sparse_fwd` kernel from `flash_mla` library. Can run on both Hopper and Blackwell GPUs. It requires bf16 q, kv inputs. diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index f5db3d7a3068..55aaadca2a1b 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -330,6 +330,25 @@ def _get_topk_paged( topk_result = metadata.topk_transform(logits, self.index_topk) return topk_result + def _should_chunk_mqa_logits( + self, num_q: int, num_k: int, device: torch.device + ) -> Tuple[bool, int]: + """ + Detect whether we need to chunk the MQA logits computation to avoid OOM + Return: (need_chunk, free_mem) + """ + # Quick static check for normal batches + if num_q * num_k < 8_000_000: # 8M elements ≈ 32MB logits + return False, 0 + + free_mem, total_mem = torch.cuda.mem_get_info(device) + bytes_per_elem = 4 # float32 + logits_bytes = num_q * num_k * bytes_per_elem + + # Logits should not exceed 50% of free memory or 30% of total memory + need_chunk = (logits_bytes * 2 > free_mem) or (logits_bytes > total_mem * 0.3) + return need_chunk, free_mem + def _get_topk_ragged( self, forward_batch: ForwardBatch, @@ -409,24 +428,86 @@ def _get_topk_ragged( # ks = [0, 0, 0, 10, 10] # ke = [8, 9, 10, 13, 14] - logits = deep_gemm.fp8_mqa_logits( - q_fp8[:q_offset], - kv_fp8, - weights[:q_offset], - ks, - ke, - clean_logits=False, - ) - 
token_nums, _, _ = q_fp8.shape - assert logits.shape[0] == len(seq_lens_expanded) - assert logits.shape[1] == k_offset + device = q_fp8.device + + # Check if we need to chunk to avoid OOM + need_chunk, free_mem = self._should_chunk_mqa_logits(q_offset, k_offset, device) + + if not need_chunk: + logits = deep_gemm.fp8_mqa_logits( + q_fp8[:q_offset], + kv_fp8, + weights[:q_offset], + ks, + ke, + clean_logits=False, + ) + assert logits.shape[0] == len(seq_lens_expanded) + assert logits.shape[1] == k_offset + + raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks) + topk_result = torch.full( + (token_nums, self.index_topk), + -1, + device=device, + dtype=torch.int32, + ) + topk_result[:q_offset] = raw_topk_result + return topk_result + + # Chunk path + bytes_per_elem = 4 # float32 + bytes_per_row = k_offset * bytes_per_elem + # Reserve 50% of free memory for logits + max_rows = max(1, int((free_mem * 0.5) // max(bytes_per_row, 1))) + max_rows = min(max_rows, q_offset) + + global_topk_offset = metadata.attn_metadata.topk_indices_offset + + assert ( + seq_lens_expanded.shape[0] == q_offset + ), f"seq_lens_expanded length mismatch: {seq_lens_expanded.shape[0]} != {q_offset}" + if global_topk_offset is not None: + assert ( + global_topk_offset.shape[0] >= q_offset + ), f"topk_indices_offset too short: {global_topk_offset.shape[0]} < {q_offset}" - raw_topk_result = metadata.topk_transform(logits, self.index_topk, ks=ks) topk_result = torch.full( - (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32 + (token_nums, self.index_topk), -1, device=device, dtype=torch.int32 ) - topk_result[:q_offset] = raw_topk_result + + start = 0 + while start < q_offset: + end = min(start + max_rows, q_offset) + + logits_chunk = deep_gemm.fp8_mqa_logits( + q_fp8[start:end], + kv_fp8, + weights[start:end], + ks[start:end], + ke[start:end], + clean_logits=False, + ) + + lengths_chunk = seq_lens_expanded[start:end] + + topk_offset_chunk = ( + 
global_topk_offset[start:end] + if global_topk_offset is not None + else None + ) + + raw_topk_chunk = metadata.topk_transform( + logits_chunk, + self.index_topk, + ks=ks[start:end], + ke_offset=lengths_chunk, + topk_indices_offset_override=topk_offset_chunk, + ) + topk_result[start:end] = raw_topk_chunk + start = end + return topk_result def _forward_cuda_k_only( diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 66fce2bc3d08..5aa853010e2a 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -170,6 +170,7 @@ def topk_transform( cu_seqlens_q: torch.Tensor = None, ke_offset: torch.Tensor = None, batch_idx_list: List[int] = None, + topk_indices_offset_override: Optional[torch.Tensor] = None, ) -> torch.Tensor: from sgl_kernel import ( fast_topk_transform_fused, @@ -177,7 +178,10 @@ def topk_transform( fast_topk_v2, ) - if cu_seqlens_q is not None: + if topk_indices_offset_override is not None: + cu_topk_indices_offset = topk_indices_offset_override + cu_seqlens_q_topk = None + elif cu_seqlens_q is not None: cu_seqlens_q = cu_seqlens_q.to(torch.int32) cu_seqlens_q_topk = compute_cu_seqlens(cu_seqlens_q) cu_topk_indices_offset = torch.repeat_interleave( @@ -286,9 +290,11 @@ def __init__( ) self.speculative_step_id = speculative_step_id + self.device_capability = torch.cuda.get_device_capability() + self.device_sm_major = self.device_capability[0] + # Allocate global workspace buffer for TRTLLm ragged attention kernel (SM100/B200) - device_sm_major = torch.cuda.get_device_capability()[0] - if device_sm_major >= 10: + if self.device_sm_major >= 10: global global_workspace_buffer if global_workspace_buffer is None: global_workspace_buffer = torch.empty( @@ -921,6 +927,11 @@ def forward_extend( q_nope = q_all[:, :, : layer.v_head_dim] q_rope = q_all[:, :, layer.v_head_dim :] + # Align topk_indices with q dimensions + # This handles 
cases where q is padded (TP + partial DP attention) + if topk_indices is not None: + topk_indices = self._pad_topk_indices(topk_indices, q_nope.shape[0]) + # NOTE(dark): here, we use page size = 1 topk_transform_method = self.get_topk_transform_method() if NSA_FUSE_TOPK: @@ -1058,6 +1069,10 @@ def forward_decode( q_nope = q_all[:, :, : layer.v_head_dim] q_rope = q_all[:, :, layer.v_head_dim :] + # Align topk_indices with q dimensions + if topk_indices is not None: + topk_indices = self._pad_topk_indices(topk_indices, q_nope.shape[0]) + if NSA_FUSE_TOPK: page_table_1 = topk_indices else: @@ -1178,13 +1193,43 @@ def _forward_flashmla_sparse( ) -> torch.Tensor: from sgl_kernel.flash_mla import flash_mla_sparse_fwd + # FlashMLA sparse kernel requires num_heads to be a multiple of 64 (Hopper) or 128 (Blackwell) + # When using TP, num_heads might be smaller (e.g., 256//8=32) + num_tokens, num_heads, head_dim = q_all.shape + + # Determine required padding based on GPU architecture (use cached value) + required_padding = 128 if self.device_sm_major >= 10 else 64 + + need_padding = num_heads % required_padding != 0 + + if need_padding: + assert required_padding % num_heads == 0, ( + f"num_heads {num_heads} cannot be padded to {required_padding}. " + f"TP size may be too large for this model." 
+ ) + + # Pad q to required size + q_padded = q_all.new_zeros((num_tokens, required_padding, head_dim)) + q_padded[:, :num_heads, :] = q_all + q_input = q_padded + else: + q_input = q_all + + # indices shape must be (s_q, h_kv=1, topk), keep h_kv=1 unchanged + indices_input = page_table_1.unsqueeze(1) + o, _, _ = flash_mla_sparse_fwd( - q=q_all, + q=q_input, kv=kv_cache, - indices=page_table_1.unsqueeze(1), + indices=indices_input, sm_scale=sm_scale, d_v=v_head_dim, ) + + # Trim output back to original num_heads if we padded + if need_padding: + o = o[:, :num_heads, :] + return o def _forward_flashmla_kv( @@ -1259,8 +1304,7 @@ def _forward_standard_mha( ) # Use TRTLLm ragged attention for SM100 (Blackwell/B200) to avoid FA4 accuracy issues - device_sm_major = torch.cuda.get_device_capability()[0] - if device_sm_major >= 10: + if self.device_sm_major >= 10: import flashinfer seq_lens = metadata.cache_seqlens_int32 @@ -1357,6 +1401,27 @@ def _forward_aiter( # kv_cache = kv_cache.view(-1, 1, layer.head_dim) return o + def _pad_topk_indices( + self, topk_indices: torch.Tensor, num_tokens: int + ) -> torch.Tensor: + current_tokens = topk_indices.shape[0] + if current_tokens == num_tokens: + return topk_indices + + assert current_tokens <= num_tokens, ( + f"topk_indices rows ({current_tokens}) > num_tokens ({num_tokens}); " + "this indicates a mismatch between indexer output and q layout." 
+ ) + + pad_size = num_tokens - current_tokens + padding = torch.full( + (pad_size, topk_indices.shape[1]), + -1, + dtype=topk_indices.dtype, + device=topk_indices.device, + ) + return torch.cat([topk_indices, padding], dim=0) + def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph.""" return 1 diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7a718795a449..18c073a897cd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -963,7 +963,12 @@ def _handle_model_specific_adjustments(self): f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} " ) else: - self.dp_size = self.tp_size + # Pure TP and partial DP Attention mode is active for NSA, logging a warning + if self.dp_size < self.tp_size: + logger.warning( + f"NSA with TP mode is active, dp_size={self.dp_size}, tp_size={self.tp_size}, " + f"attn_tp_size={self.tp_size}, attention weights will be sharded across {self.tp_size} ranks." 
+ ) self.page_size = 64 logger.warning("Setting page size to 64 for DeepSeek NSA.") diff --git a/test/manual/nightly/test_deepseek_v32_perf.py b/test/manual/nightly/test_deepseek_v32_perf.py index f7ed778c0723..d395fdd0aa8e 100644 --- a/test/manual/nightly/test_deepseek_v32_perf.py +++ b/test/manual/nightly/test_deepseek_v32_perf.py @@ -25,6 +25,9 @@ def setUpClass(cls): "--trust-remote-code", "--tp", "8", + "--dp", + "8", + "--enable-dp-attention", "--model-loader-extra-config", '{"enable_multithread_load": true}', ], @@ -35,6 +38,9 @@ def setUpClass(cls): "--trust-remote-code", "--tp", "8", + "--dp", + "8", + "--enable-dp-attention", "--speculative-algorithm", "EAGLE", "--speculative-num-steps", @@ -51,6 +57,25 @@ def setUpClass(cls): }, { "name": "nsa", + "other_args": [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "8", + "--enable-dp-attention", + "--attention-backend", + "nsa", + "--nsa-prefill-backend", + "flashmla_sparse", + "--nsa-decode-backend", + "flashmla_kv", + "--model-loader-extra-config", + '{"enable_multithread_load": true}', + ], + }, + { + "name": "pure_tp", "other_args": [ "--trust-remote-code", "--tp", diff --git a/test/nightly/test_deepseek_v32_nsabackend.py b/test/nightly/test_deepseek_v32_nsabackend.py index 45e2d665aadf..3896c93b70a5 100644 --- a/test/nightly/test_deepseek_v32_nsabackend.py +++ b/test/nightly/test_deepseek_v32_nsabackend.py @@ -197,6 +197,63 @@ def test_a_gsm8k( self.assertGreater(metrics["accuracy"], 0.935) +class TestDeepseekV32NasBackend_pure_tp(CustomTestCase): + """Test DeepSeek V3.2 with pure TP mode (no DP attention).""" + + @classmethod + def setUpClass(cls): + cls.model = DEEPSEEK_V32_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + # Pure TP configuration without --dp and --enable-dp-attention + other_args = [ + "--trust-remote-code", + "--attention-backend", + "nsa", + "--nsa-prefill-backend", + "flashmla_sparse", + "--nsa-decode-backend", + "flashmla_kv", + "--tp", + "8", + ] + cls.process = 
popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_a_gsm8k(self): + """Test GSM8K accuracy with pure TP mode.""" + args = SimpleNamespace( + num_shots=20, + data_path=None, + num_questions=1400, + parallel=1400, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + if is_in_ci(): + TEST_RESULTS.append( + { + "variant": "pure_tp", + "prefill_backend": "flashmla_sparse", + "decode_backend": "flashmla_kv", + "kv_cache": "fp16", + "accuracy": metrics["accuracy"], + } + ) + self.assertGreater(metrics["accuracy"], 0.935) + + def _write_summary_table(): """Write a markdown table with all test results.""" if not TEST_RESULTS: diff --git a/test/nightly/test_deepseek_v32_perf.py b/test/nightly/test_deepseek_v32_perf.py index f7ed778c0723..d395fdd0aa8e 100644 --- a/test/nightly/test_deepseek_v32_perf.py +++ b/test/nightly/test_deepseek_v32_perf.py @@ -25,6 +25,9 @@ def setUpClass(cls): "--trust-remote-code", "--tp", "8", + "--dp", + "8", + "--enable-dp-attention", "--model-loader-extra-config", '{"enable_multithread_load": true}', ], @@ -35,6 +38,9 @@ def setUpClass(cls): "--trust-remote-code", "--tp", "8", + "--dp", + "8", + "--enable-dp-attention", "--speculative-algorithm", "EAGLE", "--speculative-num-steps", @@ -51,6 +57,25 @@ def setUpClass(cls): }, { "name": "nsa", + "other_args": [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "8", + "--enable-dp-attention", + "--attention-backend", + "nsa", + "--nsa-prefill-backend", + "flashmla_sparse", + "--nsa-decode-backend", + "flashmla_kv", + "--model-loader-extra-config", + '{"enable_multithread_load": true}', + ], + }, + { + "name": "pure_tp", "other_args": [ "--trust-remote-code", "--tp",