diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 0c80774274d1..b83e0ad1595a 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -21,6 +21,7 @@ NSA_FUSE_TOPK, compute_nsa_seqlens, ) +from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_hip @@ -911,7 +912,7 @@ def forward_extend( if NSA_PREFILL_IMPL == "tilelang": if q_rope is not None: - q_all = torch.cat([q_nope, q_rope], dim=-1) + q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_tilelang( q_all=q_all, kv_cache=kv_cache, @@ -921,7 +922,7 @@ def forward_extend( ) elif NSA_PREFILL_IMPL == "flashmla_sparse": if q_rope is not None: - q_all = torch.cat([q_nope, q_rope], dim=-1) + q_all = _concat_mla_absorb_q_general(q_nope, q_rope) # NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here, # because the flashmla_sparse kernel doesn't support fp8 compute @@ -947,7 +948,7 @@ def forward_extend( ) elif NSA_PREFILL_IMPL == "flashmla_kv": if q_rope is not None: - q_all = torch.cat([q_nope, q_rope], dim=-1) + q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_flashmla_kv( q_all=q_all, kv_cache=kv_cache, @@ -1031,7 +1032,7 @@ def forward_decode( if NSA_DECODE_IMPL == "flashmla_sparse": if q_rope is not None: - q_all = torch.cat([q_nope, q_rope], dim=-1) + q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_flashmla_sparse( q_all=q_all, kv_cache=kv_cache, @@ -1041,7 +1042,7 @@ def forward_decode( ) elif NSA_DECODE_IMPL == "flashmla_kv": if q_rope is not None: - q_all = torch.cat([q_nope, q_rope], dim=-1) + q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_flashmla_kv( q_all=q_all, kv_cache=kv_cache, @@ -1054,7 +1055,7 @@ def forward_decode( ) elif NSA_DECODE_IMPL == "tilelang": if q_rope is not None: - q_all = torch.cat([q_nope, q_rope], dim=-1) + q_all = _concat_mla_absorb_q_general(q_nope, q_rope) return self._forward_tilelang( q_all=q_all, kv_cache=kv_cache,