Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,7 @@ def forward_extend(
cos_sin_cache,
is_neox,
llama_4_scaling,
is_prefill=True,
)

if k is not None:
Expand Down Expand Up @@ -1929,6 +1930,7 @@ def _forward_trtllm(
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
llama_4_scaling: Optional[torch.Tensor] = None,
is_prefill: bool = False,
) -> torch.Tensor:
"""Forward using TRT-LLM sparse MLA kernel."""
import flashinfer.decode
Expand Down Expand Up @@ -1990,6 +1992,13 @@ def _forward_trtllm(

if envs.SGLANG_NSA_FUSE_TOPK.get():
page_table_1 = topk_indices
elif is_prefill:
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
else:
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
Expand Down
11 changes: 0 additions & 11 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,9 +1455,6 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str:
if self.dp_size == 1 and major >= 10:
self.nsa_prefill_backend = "trtllm"
self.nsa_decode_backend = "trtllm"
logger.warning(
"Flashmla is not supported on Blackwell device without DP attention. Set NSA prefill/decode backends to trtllm, which runs fast but loses a little accuracy."
)
else:
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
if not user_set_prefill:
Expand Down Expand Up @@ -1526,14 +1523,6 @@ def _handle_model_specific_adjustments(self):
logger.warning(
f"Set dense attention kv len threshold to model index_topk={envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.get()} for DeepSeek with DSA."
)
if self.nsa_prefill_backend == "trtllm":
# We temporarily set the threshold to 128k to avoid IMA error. Should be removed after supporting flashmla prefill impl with trtllm decode impl.
envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.set(
128 * 1024
)
logger.warning(
"TRTLLM sparse MLA kernel requires MHA as prefill impl, the threshold for dense attention is overridden. This will be fixed in the future."
)
if self.is_attention_backend_not_set():
self.attention_backend = "nsa"
logger.info("Use nsa attention backend for DeepSeek with DSA.")
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/test/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
def get_thinking_kwargs(args):
thinking_mode = getattr(args, "thinking_mode", None)
if thinking_mode in THINKING_MODE_CHOICES:
if thinking_mode == "deepseek-v3":
if thinking_mode in ["deepseek-v3", "kimi-k2"]:
thinking_param = "thinking"
else:
# Qwen3
# All models other than deepseek-v3 / kimi-k2 use "enable_thinking"
thinking_param = "enable_thinking"
return {thinking_param: True}
return {}
Expand Down Expand Up @@ -267,7 +267,7 @@ def run_eval(args):
return metrics


THINKING_MODE_CHOICES = ["deepseek-v3", "qwen3"]
THINKING_MODE_CHOICES = ["deepseek-v3", "qwen-3", "glm-45", "kimi-k2"]

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
Loading