diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4d546d79be80..792201428cc3 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -928,20 +928,91 @@ def _handle_model_specific_adjustments(self):
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
 
-        if model_arch in ["DeepseekV3ForCausalLM"] and not is_deepseek_nsa(hf_config):
-            if self.enable_piecewise_cuda_graph:
-                logger.info("Piecewise CUDA graph is enabled, use MLA for prefill.")
-
-            if is_cuda() and is_sm100_supported():
+        if model_arch in ["DeepseekV3ForCausalLM"]:
+            if is_deepseek_nsa(hf_config):
                 if (
                     self.attention_backend is None
                     and self.prefill_attention_backend is None
                     and self.decode_attention_backend is None
                 ):
-                    self.attention_backend = "trtllm_mla"
-                    logger.info(
-                        "Use trtllm_mla as attention backend on sm100 for DeepseekV3ForCausalLM"
+                    self.attention_backend = "nsa"
+                    logger.warning("Set nsa attention backend for DeepSeek NSA.")
+
+                if not is_npu():
+                    self.enable_dp_attention = True
+                    logger.warning("DP attention is enabled for DeepSeek NSA.")
+                    if self.enable_nsa_prefill_context_parallel:
+                        # TODO: support moe_dense_tp_size != 1, kv_cache_dtype = "fp8", a non-deepep moe_a2a_backend, and cross-machine operation.
+                        self.moe_dense_tp_size = 1
+                        self.moe_a2a_backend = "deepep"
+                        self.ep_size = self.tp_size
+                        self.kv_cache_dtype = "bf16"
+                        assert (
+                            self.tp_size == 8
+                        ), "Current multi-machine CP support suffers from precision issues, so context parallel only supports a single machine (tp_size == 8)"
+
+                        logger.warning(
+                            f"Enable context parallel opt for DeepSeek V3.2 DSA. Setting dp_size == {self.dp_size}, moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend == {self.moe_a2a_backend}"
+                        )
+                    else:
+                        self.dp_size = self.tp_size
+
+                self.page_size = 64
+                logger.warning("Setting page size to 64 for DeepSeek NSA.")
+
+                # For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
+                import torch
+
+                major, _ = torch.cuda.get_device_capability()
+                if self.kv_cache_dtype == "auto":
+                    self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
+                    logger.warning(
+                        f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
+                    )
+                if self.kv_cache_dtype == "bf16":
+                    self.kv_cache_dtype = "bfloat16"
+                assert self.kv_cache_dtype in [
+                    "bfloat16",
+                    "fp8_e4m3",
+                ], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"
+
+                if self.kv_cache_dtype == "fp8_e4m3":
+                    # flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
+                    self.nsa_prefill_backend = "flashmla_auto"
+                    self.nsa_decode_backend = "flashmla_kv"
+                    logger.warning(
+                        "Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
+                    )
+                else:
+                    # Set prefill/decode backends for Blackwell. The default settings are for Hopper.
+                    if major >= 10:
+                        self.nsa_prefill_backend = "flashmla_sparse"
+                        self.nsa_decode_backend = "flashmla_sparse"
+
+                # Logging env vars for NSA
+                from sglang.srt.layers.attention.nsa.utils import (
+                    print_nsa_bool_env_vars,
                 )
+
+                print_nsa_bool_env_vars()
+
+            else:
+                if self.enable_piecewise_cuda_graph:
+                    logger.info("Piecewise CUDA graph is enabled, use MLA for prefill.")
+
+                if is_cuda() and is_sm100_supported():
+                    if (
+                        self.attention_backend is None
+                        and self.prefill_attention_backend is None
+                        and self.decode_attention_backend is None
+                    ):
+                        self.attention_backend = "trtllm_mla"
+                        logger.info(
+                            "Use trtllm_mla as attention backend on sm100 for DeepseekV3ForCausalLM"
+                        )
+
+            # common to all Deepseek MoE models
+            if is_cuda() and is_sm100_supported():
                 # workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006
                 if not self.enable_dp_attention and self.nnodes == 1:
                     self.enable_flashinfer_allreduce_fusion = True
@@ -1148,72 +1219,6 @@ def _handle_model_specific_adjustments(self):
                 logger.info(
                     "Use flashinfer_trtllm as MoE runner backend on sm100 for Qwen3NextForCausalLM"
                 )
-        if is_deepseek_nsa(hf_config):
-            if (
-                self.attention_backend is None
-                and self.prefill_attention_backend is None
-                and self.decode_attention_backend is None
-            ):
-                self.attention_backend = "nsa"
-                logger.warning("Set nsa attention backend for DeepSeek NSA.")
-
-            if not is_npu():
-                self.enable_dp_attention = True
-                logger.warning("DP attention is enabled for DeepSeek NSA.")
-                if self.enable_nsa_prefill_context_parallel:
-                    # TODO Supports moe_dense_tp_size != 1, kv cache dtype = "fp8",moe_a2a_backend non-deepep and cross-machine operation .
-                    self.moe_dense_tp_size = 1
-                    self.moe_a2a_backend = "deepep"
-                    self.ep_size = self.tp_size
-                    self.kv_cache_dtype = "bf16"
-                    assert (
-                        self.tp_size == 8
-                    ), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)"
-
-                    logger.warning(
-                        f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} "
-                    )
-                else:
-                    self.dp_size = self.tp_size
-
-            self.page_size = 64
-            logger.warning("Setting page size to 64 for DeepSeek NSA.")
-
-            # For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
-            import torch
-
-            major, _ = torch.cuda.get_device_capability()
-            if self.kv_cache_dtype == "auto":
-                self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
-                logger.warning(
-                    f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
-                )
-            if self.kv_cache_dtype == "bf16":
-                self.kv_cache_dtype = "bfloat16"
-            assert self.kv_cache_dtype in [
-                "bfloat16",
-                "fp8_e4m3",
-            ], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"
-
-            if self.kv_cache_dtype == "fp8_e4m3":
-                # flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
-                self.nsa_prefill_backend = "flashmla_auto"
-                self.nsa_decode_backend = "flashmla_kv"
-                logger.warning(
-                    "Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
-                )
-            else:
-                # set prefill/decode backends for Blackwell. The default settings are for Hopper.
-                if major >= 10:
-                    self.nsa_prefill_backend = "flashmla_sparse"
-                    self.nsa_decode_backend = "flashmla_sparse"
-
-            # Logging env vars for NSA
-            from sglang.srt.layers.attention.nsa.utils import (
-                print_nsa_bool_env_vars,
-            )
-
-            print_nsa_bool_env_vars()
 
     def _handle_sampling_backend(self):
         if self.sampling_backend is None:
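Reviewer note, not part of the patch: the sketch below restates the KV-cache dtype and NSA backend dispatch performed by the relocated is_deepseek_nsa branch, so the dtype/backend matrix is easy to check at a glance. The function name select_nsa_backends and the explicit sm_major and default-backend parameters are illustrative stand-ins for torch.cuda.get_device_capability() and the server's pre-existing NSA defaults; only the branching mirrors the diff.

    from typing import Tuple


    def select_nsa_backends(
        kv_cache_dtype: str,
        sm_major: int,
        default_prefill: str,
        default_decode: str,
    ) -> Tuple[str, str, str]:
        """Return (kv_cache_dtype, nsa_prefill_backend, nsa_decode_backend).

        sm_major is the CUDA compute-capability major version:
        >= 10 is treated as Blackwell, 9 as Hopper, matching the diff.
        """
        # "auto" resolves to fp8 on Blackwell and bf16 on Hopper.
        if kv_cache_dtype == "auto":
            kv_cache_dtype = "fp8_e4m3" if sm_major >= 10 else "bfloat16"
        if kv_cache_dtype == "bf16":
            kv_cache_dtype = "bfloat16"
        assert kv_cache_dtype in ("bfloat16", "fp8_e4m3")

        if kv_cache_dtype == "fp8_e4m3":
            # flashmla_auto later dispatches to flashmla_sparse/flashmla_kv.
            return kv_cache_dtype, "flashmla_auto", "flashmla_kv"
        if sm_major >= 10:
            # bf16 KV cache on Blackwell uses the sparse kernels for both phases.
            return kv_cache_dtype, "flashmla_sparse", "flashmla_sparse"
        # bf16 KV cache on Hopper keeps whatever defaults were already set.
        return kv_cache_dtype, default_prefill, default_decode


    # Example: Hopper (sm90) with "auto" resolves to bf16 plus the existing defaults;
    # Blackwell (sm100) with "auto" resolves to fp8_e4m3 plus flashmla_auto/flashmla_kv.
    assert select_nsa_backends("auto", 9, "p_default", "d_default") == (
        "bfloat16", "p_default", "d_default",
    )
    assert select_nsa_backends("auto", 10, "p_default", "d_default") == (
        "fp8_e4m3", "flashmla_auto", "flashmla_kv",
    )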