153 changes: 79 additions & 74 deletions python/sglang/srt/server_args.py
@@ -928,20 +928,91 @@ def _handle_model_specific_adjustments(self):

hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["DeepseekV3ForCausalLM"] and not is_deepseek_nsa(hf_config):
if self.enable_piecewise_cuda_graph:
logger.info("Piecewise CUDA graph is enabled, use MLA for prefill.")

if is_cuda() and is_sm100_supported():
if model_arch in ["DeepseekV3ForCausalLM"]:
if is_deepseek_nsa(hf_config):
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "trtllm_mla"
logger.info(
"Use trtllm_mla as attention backend on sm100 for DeepseekV3ForCausalLM"
self.attention_backend = "nsa"
logger.warning("Set nsa attention backend for DeepSeek NSA.")

if not is_npu():
self.enable_dp_attention = True
logger.warning("DP attention is enabled for DeepSeek NSA.")
if self.enable_nsa_prefill_context_parallel:
# TODO: support moe_dense_tp_size != 1, kv_cache_dtype == "fp8", non-deepep moe_a2a_backend, and cross-machine operation.
self.moe_dense_tp_size = 1
self.moe_a2a_backend = "deepep"
self.ep_size = self.tp_size
self.kv_cache_dtype = "bf16"
assert (
self.tp_size == 8
), "Multi-machine CP currently suffers from precision issues, so context parallel only supports a single machine (tp_size == 8)."

logger.warning(
f"Enabling context parallel for deepseekv3.2-DSA: setting dp_size == {self.dp_size}, moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend == {self.moe_a2a_backend}"
)
else:
self.dp_size = self.tp_size

self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")

# For Hopper, we support both bf16 and fp8 KV cache; Blackwell currently defaults to fp8
import torch

major, _ = torch.cuda.get_device_capability()
if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"

if self.kv_cache_dtype == "fp8_e4m3":
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
self.nsa_prefill_backend = "flashmla_auto"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
)
else:
# set prefill/decode backends for Blackwell. The default settings are for Hopper.
if major >= 10:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_decode_backend = "flashmla_sparse"

# Logging env vars for NSA
from sglang.srt.layers.attention.nsa.utils import (
print_nsa_bool_env_vars,
)

print_nsa_bool_env_vars()

else:
if self.enable_piecewise_cuda_graph:
logger.info("Piecewise CUDA graph is enabled, use MLA for prefill.")

if is_cuda() and is_sm100_supported():
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "trtllm_mla"
logger.info(
"Use trtllm_mla as attention backend on sm100 for DeepseekV3ForCausalLM"
)

# common to all Deepseek MoE models
if is_cuda() and is_sm100_supported():
# workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006
if not self.enable_dp_attention and self.nnodes == 1:
self.enable_flashinfer_allreduce_fusion = True
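
Editor's note: the NSA block above folds three decisions into one pass: resolving kv_cache_dtype from the GPU's compute capability, normalizing the dtype string, and picking per-phase NSA backends. As a reading aid, here is a minimal standalone sketch of that decision table, assuming nothing beyond what the diff shows; resolve_nsa_backends and the "<hopper default>" placeholder are hypothetical names, not sglang API (the Hopper bf16 backends are whatever defaults ServerArgs already carries, untouched by this hunk).

# Sketch only (not part of the PR): the NSA dtype/backend resolution above,
# extracted as a pure function for readability.
def resolve_nsa_backends(kv_cache_dtype: str, major: int) -> tuple:
    """Return (kv_cache_dtype, nsa_prefill_backend, nsa_decode_backend).

    `major` is the CUDA compute capability major: 9 = Hopper (sm90),
    10 = Blackwell (sm100), as returned by torch.cuda.get_device_capability().
    """
    if kv_cache_dtype == "auto":
        # Blackwell defaults to fp8 KV cache, Hopper to bf16.
        kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
    if kv_cache_dtype == "bf16":
        kv_cache_dtype = "bfloat16"
    assert kv_cache_dtype in ("bfloat16", "fp8_e4m3")

    if kv_cache_dtype == "fp8_e4m3":
        # flashmla_auto dispatches to flashmla_sparse/flashmla_kv by heuristics.
        return kv_cache_dtype, "flashmla_auto", "flashmla_kv"
    if major >= 10:
        # bf16 on Blackwell: sparse kernels for both prefill and decode.
        return kv_cache_dtype, "flashmla_sparse", "flashmla_sparse"
    # bf16 on Hopper keeps the pre-existing defaults (unchanged by this hunk).
    return kv_cache_dtype, "<hopper default>", "<hopper default>"

# Example: "auto" on Blackwell resolves to fp8_e4m3 + flashmla_auto/flashmla_kv.
print(resolve_nsa_backends("auto", 10))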
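
The context-parallel branch likewise pins several settings at once. Below is a hedged sketch of those invariants gathered into one hypothetical validator; check_nsa_cp_config is illustrative only, since the real code mutates ServerArgs in place.

# Sketch only: the settings the enable_nsa_prefill_context_parallel branch
# forces, collected in one hypothetical helper.
def check_nsa_cp_config(tp_size: int) -> dict:
    # Multi-machine CP currently has precision issues, so only a single
    # 8-GPU machine is supported.
    assert tp_size == 8, "context parallel requires a single machine (tp_size == 8)"
    return {
        "moe_dense_tp_size": 1,       # dense-layer TP other than 1 is still TODO
        "moe_a2a_backend": "deepep",  # non-deepep all-to-all is still TODO
        "ep_size": tp_size,           # expert parallelism spans all 8 GPUs
        "kv_cache_dtype": "bf16",     # fp8 KV cache with CP is still TODO
    }

print(check_nsa_cp_config(8))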
@@ -1148,72 +1219,6 @@ def _handle_model_specific_adjustments(self):
logger.info(
"Use flashinfer_trtllm as MoE runner backend on sm100 for Qwen3NextForCausalLM"
)
if is_deepseek_nsa(hf_config):
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "nsa"
logger.warning("Set nsa attention backend for DeepSeek NSA.")

if not is_npu():
self.enable_dp_attention = True
logger.warning("DP attention is enabled for DeepSeek NSA.")
if self.enable_nsa_prefill_context_parallel:
# TODO: support moe_dense_tp_size != 1, kv_cache_dtype == "fp8", non-deepep moe_a2a_backend, and cross-machine operation.
self.moe_dense_tp_size = 1
self.moe_a2a_backend = "deepep"
self.ep_size = self.tp_size
self.kv_cache_dtype = "bf16"
assert (
self.tp_size == 8
), "Multi-machine CP currently suffers from precision issues, so context parallel only supports a single machine (tp_size == 8)."

logger.warning(
f"Enabling context parallel for deepseekv3.2-DSA: setting dp_size == {self.dp_size}, moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend == {self.moe_a2a_backend}"
)
else:
self.dp_size = self.tp_size

self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")

# For Hopper, we support both bf16 and fp8 KV cache; Blackwell currently defaults to fp8
import torch

major, _ = torch.cuda.get_device_capability()
if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"

if self.kv_cache_dtype == "fp8_e4m3":
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
self.nsa_prefill_backend = "flashmla_auto"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
)
else:
# set prefill/decode backends for Blackwell. The default settings are for Hopper.
if major >= 10:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_decode_backend = "flashmla_sparse"

# Logging env vars for NSA
from sglang.srt.layers.attention.nsa.utils import (
print_nsa_bool_env_vars,
)

print_nsa_bool_env_vars()

def _handle_sampling_backend(self):
if self.sampling_backend is None: