Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--sampling-backend` | Choose the kernels for sampling layers. | None |
| `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
| `--mm-attention-backend` | Set multimodal attention backend. | None |
| `--nsa-prefill-backend` | Prefill attention implementation for the NSA backend. | `flashmla_sparse` |
| `--nsa-decode-backend` | Decode attention implementation for the NSA backend. | `flashmla_kv` |

## Speculative decoding

Expand Down
40 changes: 19 additions & 21 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
)


_NSA_IMPL_T: TypeAlias = Literal[
"flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]
_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]

NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
Expand Down Expand Up @@ -181,8 +179,8 @@ def __init__(
self.req_to_token = model_runner.req_to_token_pool.req_to_token

global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend

self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)

Expand Down Expand Up @@ -336,7 +334,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
cache_seqlens=nsa_cache_seqlens_int32,
seq_len_q=1,
)
if NSA_DECODE_IMPL == "flashmla_decode"
if NSA_DECODE_IMPL == "flashmla_kv"
else None
),
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
Expand Down Expand Up @@ -383,7 +381,7 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
),
seq_len_q=1,
)
if NSA_DECODE_IMPL == "flashmla_decode"
if NSA_DECODE_IMPL == "flashmla_kv"
else None
),
}
Expand Down Expand Up @@ -421,7 +419,7 @@ def init_forward_metadata_capture_cuda_graph(

seqlens_expanded = cache_seqlens_int32
nsa_extend_seq_lens_list = [1] * num_tokens
if NSA_DECODE_IMPL == "flashmla_decode":
if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, num_tokens + 1))
Expand Down Expand Up @@ -478,7 +476,7 @@ def init_forward_metadata_capture_cuda_graph(
)
nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens

if NSA_DECODE_IMPL == "flashmla_decode":
if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
Expand Down Expand Up @@ -534,7 +532,7 @@ def init_forward_metadata_capture_cuda_graph(
)
nsa_extend_seq_lens_list = [1] * bs

if NSA_DECODE_IMPL == "flashmla_decode":
if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = self.decode_cuda_graph_metadata[
"flashmla_metadata"
].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
Expand Down Expand Up @@ -712,7 +710,7 @@ def init_forward_metadata_replay_cuda_graph(
else:
assert metadata.real_page_table is metadata.page_table_1

if NSA_DECODE_IMPL == "flashmla_decode":
if NSA_DECODE_IMPL == "flashmla_kv":
flashmla_metadata = metadata.flashmla_metadata.slice(
slice(0, seqlens_expanded_size + 1)
)
Expand Down Expand Up @@ -803,20 +801,20 @@ def forward_extend(
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_prefill":
elif NSA_PREFILL_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
return self._forward_flashmla_sparse(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_PREFILL_IMPL == "flashmla_decode":
elif NSA_PREFILL_IMPL == "flashmla_kv":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
return self._forward_flashmla_kv(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
Expand Down Expand Up @@ -897,20 +895,20 @@ def forward_decode(
page_size=1,
)

if NSA_DECODE_IMPL == "flashmla_prefill":
if NSA_DECODE_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_prefill(
return self._forward_flashmla_sparse(
q_all=q_all,
kv_cache=kv_cache,
page_table_1=page_table_1,
sm_scale=layer.scaling,
v_head_dim=layer.v_head_dim,
)
elif NSA_DECODE_IMPL == "flashmla_decode":
elif NSA_DECODE_IMPL == "flashmla_kv":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_flashmla_decode(
return self._forward_flashmla_kv(
q_all=q_all,
kv_cache=kv_cache,
sm_scale=layer.scaling,
Expand Down Expand Up @@ -998,7 +996,7 @@ def _forward_fa3(
)
return o # type: ignore

def _forward_flashmla_prefill(
def _forward_flashmla_sparse(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
Expand All @@ -1017,7 +1015,7 @@ def _forward_flashmla_prefill(
)
return o

def _forward_flashmla_decode(
def _forward_flashmla_kv(
self,
q_all: torch.Tensor,
kv_cache: torch.Tensor,
Expand Down
20 changes: 10 additions & 10 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@

DEFAULT_LORA_EVICTION_POLICY = "lru"

NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]

RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]

Expand Down Expand Up @@ -324,8 +324,8 @@ class ServerArgs:
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
nsa_prefill: str = "flashmla_prefill"
nsa_decode: str = "fa3"
nsa_prefill_backend: str = "flashmla_sparse"
nsa_decode_backend: str = "fa3"

# Speculative decoding
enable_beta_spec: bool = False
Expand Down Expand Up @@ -1024,10 +1024,10 @@ def _handle_model_specific_adjustments(self):
logger.warning("Setting KV cache dtype to fp8.")

if self.kv_cache_dtype == "fp8_e4m3":
self.nsa_prefill = "flashmla_decode"
self.nsa_decode = "flashmla_decode"
self.nsa_prefill_backend = "flashmla_kv"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_decode for FP8 KV Cache."
"Setting NSA backend to flashmla_kv for FP8 KV Cache."
)

# Logging env vars for NSA
Expand Down Expand Up @@ -2356,14 +2356,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Set multimodal attention backend.",
)
parser.add_argument(
"--nsa-prefill",
default=ServerArgs.nsa_prefill,
"--nsa-prefill-backend",
default=ServerArgs.nsa_prefill_backend,
type=str,
choices=NSA_CHOICES,
)
parser.add_argument(
"--nsa-decode",
default=ServerArgs.nsa_decode,
"--nsa-decode-backend",
default=ServerArgs.nsa_decode_backend,
type=str,
choices=NSA_CHOICES,
)
Expand Down
Loading