diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index dd3754f1b4bb..862488e5f918 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -1297,6 +1297,7 @@ def forward_extend(
                 cos_sin_cache,
                 is_neox,
                 llama_4_scaling,
+                is_prefill=True,
             )
 
         if k is not None:
@@ -1929,6 +1930,7 @@ def _forward_trtllm(
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
         llama_4_scaling: Optional[torch.Tensor] = None,
+        is_prefill: bool = False,
     ) -> torch.Tensor:
         """Forward using TRT-LLM sparse MLA kernel."""
         import flashinfer.decode
@@ -1990,6 +1992,13 @@ def _forward_trtllm(
 
         if envs.SGLANG_NSA_FUSE_TOPK.get():
             page_table_1 = topk_indices
+        elif is_prefill:
+            page_table_1 = transform_index_page_table_prefill(
+                page_table=metadata.page_table_1,
+                topk_indices=topk_indices,
+                extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+                page_size=1,
+            )
         else:
             page_table_1 = transform_index_page_table_decode(
                 page_table=metadata.page_table_1,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 0186be08a387..11bf2ba6c71c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1455,9 +1455,6 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str:
         if self.dp_size == 1 and major >= 10:
             self.nsa_prefill_backend = "trtllm"
             self.nsa_decode_backend = "trtllm"
-            logger.warning(
-                "Flashmla is not supported on Blackwell device without DP attention. Set NSA prefill/decode backends to trtllm, which runs fast but loses a little accuracy."
-            )
         else:
             # flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
             if not user_set_prefill:
@@ -1526,14 +1523,6 @@ def _handle_model_specific_adjustments(self):
                 logger.warning(
                     f"Set dense attention kv len threshold to model index_topk={envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.get()} for DeepSeek with DSA."
                 )
-                if self.nsa_prefill_backend == "trtllm":
-                    # We temporarily set the threshold to 128k to avoid IMA error. Should be removed after supporting flashmla prefill impl with trtllm decode impl.
-                    envs.SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD.set(
-                        128 * 1024
-                    )
-                    logger.warning(
-                        "TRTLLM sparse MLA kernel requires MHA as prefill impl, the threshold for dense attention is overridden. This will be fixed in the future."
-                    )
         if self.is_attention_backend_not_set():
             self.attention_backend = "nsa"
             logger.info("Use nsa attention backend for DeepSeek with DSA.")
diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index a467d3fb409f..7553242a101e 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -20,10 +20,10 @@ def get_thinking_kwargs(args):
     thinking_mode = getattr(args, "thinking_mode", None)
     if thinking_mode in THINKING_MODE_CHOICES:
-        if thinking_mode == "deepseek-v3":
+        if thinking_mode in ["deepseek-v3", "kimi-k2"]:
             thinking_param = "thinking"
         else:
-            # Qwen3
+            # All models other than deepseek-v3/kimi-k2
             thinking_param = "enable_thinking"
         return {thinking_param: True}
     return {}
@@ -267,7 +267,7 @@ def run_eval(args):
     return metrics
 
 
-THINKING_MODE_CHOICES = ["deepseek-v3", "qwen3"]
+THINKING_MODE_CHOICES = ["deepseek-v3", "qwen-3", "glm-45", "kimi-k2"]
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()