From d92751861c526b77510ee7d7d46028e1fad1292c Mon Sep 17 00:00:00 2001
From: Guangda Liu
Date: Mon, 27 Oct 2025 12:55:13 +0000
Subject: [PATCH] dsv32: use _concat_mla_absorb_q_general to replace torch.cat

---
 python/sglang/srt/layers/attention/nsa_backend.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index 7da15cc47089..b65798b325be 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -18,6 +18,7 @@
     NSA_FUSE_TOPK,
     compute_nsa_seqlens,
 )
+from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip
@@ -793,7 +794,7 @@ def forward_extend(
         )
         if NSA_PREFILL_IMPL == "tilelang":
             if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
+                q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_tilelang(
                 q_all=q_all,
                 kv_cache=kv_cache,
@@ -803,7 +804,7 @@ def forward_extend(
             )
         elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
+                q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
@@ -813,7 +814,7 @@ def forward_extend(
             )
         elif NSA_PREFILL_IMPL == "flashmla_kv":
             if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
+                q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
@@ -897,7 +898,7 @@ def forward_decode(
         )
         if NSA_DECODE_IMPL == "flashmla_sparse":
             if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
+                q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
@@ -907,7 +908,7 @@ def forward_decode(
             )
         elif NSA_DECODE_IMPL == "flashmla_kv":
             if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
+                q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
@@ -920,7 +921,7 @@ def forward_decode(
             )
         elif NSA_DECODE_IMPL == "tilelang":
             if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
+                q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_tilelang(
                 q_all=q_all,
                 kv_cache=kv_cache,
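
Note on the replacement: torch.cat([q_nope, q_rope], dim=-1) goes through the
generic concat path on every call, while _concat_mla_absorb_q_general (imported
here from the TRT-LLM MLA backend) can route the fixed MLA absorb layout to a
dedicated concat kernel. Below is a minimal sketch of the preallocate-and-copy
pattern such a helper replaces torch.cat with; concat_absorb_q_sketch is a
hypothetical stand-in for illustration, not the actual sglang helper, and the
512/64 split is the DeepSeek-V3 MLA head layout assumed for concreteness:

    import torch

    def concat_absorb_q_sketch(q_nope: torch.Tensor, q_rope: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for _concat_mla_absorb_q_general: write both
        # halves into one preallocated buffer, the work a dedicated kernel can
        # do in a single launch instead of torch.cat's generic path.
        out = torch.empty(
            (*q_nope.shape[:-1], q_nope.shape[-1] + q_rope.shape[-1]),
            dtype=q_nope.dtype,
            device=q_nope.device,
        )
        out[..., : q_nope.shape[-1]].copy_(q_nope)  # absorbed no-PE part (512 dims assumed)
        out[..., q_nope.shape[-1] :].copy_(q_rope)  # RoPE part (64 dims assumed)
        return out

Since the result is bitwise-identical to torch.cat, the six call sites above
can be swapped without changing numerics; only the import and the call change.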