diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index f9de5b465471..07413e5a46dc 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -1548,7 +1548,13 @@ def forward_decode(
             q_rope = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
+            # Caller passed split q_nope / q_rope; we'll need to concat below if
+            # the chosen impl wants q_all.
+            q_all = None
         else:
+            # Caller passed already-concatenated q (q_all = q). Reuse it directly
+            # via a zero-copy view; the impl-specific blocks below will skip the
+            # otherwise redundant concat_mla_absorb_q_general call.
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
@@ -1597,7 +1603,11 @@ def forward_decode(
                 page_table_1=page_table_1,
             )
         elif self.nsa_decode_impl == "tilelang":
-            if q_rope is not None:
+            # Cat-skip (HIP-only): when the caller passes q_rope=None on HIP, q_all
+            # has already been set to a zero-copy view of q in the else branch
+            # above and we can reuse it directly. The `not _is_hip` clause keeps
+            # CUDA / MUSA paths byte-identical to pre-patch by always re-concatenating.
+            if q_all is None or not _is_hip:
                 q_all = concat_mla_absorb_q_general(q_nope, q_rope)
             return self._forward_tilelang(
                 q_all=q_all,
@@ -1622,7 +1632,7 @@ def forward_decode(
                 page_size=1,
             )
         elif self.nsa_decode_impl == "aiter":
-            if q_rope is not None:
+            if q_all is None or not _is_hip:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
             return self._forward_aiter(
                 q_all=q_all,
diff --git a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py
index 26fef866bdd9..efcd6cd91980 100644
--- a/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py
+++ b/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py
@@ -376,25 +376,55 @@ def forward_absorb_core(
                 self.rotary_emb.is_neox_style,
                 q_out_dtype=kv_cache_dtype,
             )
-            q_nope_fused = q_cat[..., : self.kv_lora_rank]
-            q_pe_fused = q_cat[..., self.kv_lora_rank :]
             save_kv_cache = False
-            if llama_4_scaling is not None:
-                q_nope_fused *= llama_4_scaling
-            attn_output = self.attn_mqa(
-                q_nope_fused,
-                None,
-                None,
-                forward_batch,
-                q_rope=q_pe_fused,
-                k_rope=k_pe_fused,
-                save_kv_cache=save_kv_cache,
-                **(
-                    dict(topk_indices=topk_indices)
-                    if topk_indices is not None
-                    else {}
-                ),
-            )
+            # On decode, pass q_cat directly to attn_mqa with q_rope=None so
+            # nsa_backend.forward_decode reuses q_cat as a zero-copy view
+            # (`q.contiguous().view(...)` fast-path) instead of running the
+            # redundant `concat_mla_absorb_q_general(q_nope_fused, q_pe_fused)`
+            # that would otherwise rebuild a tensor byte-identical to q_cat.
+            # On ROCm tilelang decode, this eliminates the
+            # `CatArrayBatchedCopy<..., ...>` kernel that used to
+            # fire once per layer per decode step (~2.6 us / layer saved).
+            # Prefill keeps the split form because nsa_backend.forward_extend
+            # asserts `q_rope is not None`.
+            if forward_batch.forward_mode.is_decode_or_idle():
+                if llama_4_scaling is not None:
+                    # llama_4_scaling applies only to the q_nope portion;
+                    # mutate in place via the slice view of q_cat.
+                    q_cat[..., : self.kv_lora_rank] *= llama_4_scaling
+                attn_output = self.attn_mqa(
+                    q_cat,
+                    None,
+                    None,
+                    forward_batch,
+                    q_rope=None,
+                    k_rope=k_pe_fused,
+                    save_kv_cache=save_kv_cache,
+                    **(
+                        dict(topk_indices=topk_indices)
+                        if topk_indices is not None
+                        else {}
+                    ),
+                )
+            else:
+                q_nope_fused = q_cat[..., : self.kv_lora_rank]
+                q_pe_fused = q_cat[..., self.kv_lora_rank :]
+                if llama_4_scaling is not None:
+                    q_nope_fused *= llama_4_scaling
+                attn_output = self.attn_mqa(
+                    q_nope_fused,
+                    None,
+                    None,
+                    forward_batch,
+                    q_rope=q_pe_fused,
+                    k_rope=k_pe_fused,
+                    save_kv_cache=save_kv_cache,
+                    **(
+                        dict(topk_indices=topk_indices)
+                        if topk_indices is not None
+                        else {}
+                    ),
+                )
         else:
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
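
Below the diff, a minimal standalone sketch of the zero-copy point the new comments make: slicing the fused query and re-concatenating the halves launches a copy kernel and allocates a fresh buffer byte-identical to the original, while a contiguous().view() over the already-fused tensor shares storage. The shapes and dimension names (kv_lora_rank, rope_dim, heads) are illustrative only and not taken from any model config; this snippet is not part of the patch.

# Sketch only: illustrates why reusing the fused query tensor is free while
# splitting and re-concatenating it is not. All sizes are made up.
import torch

tokens, heads = 4, 16
kv_lora_rank, rope_dim = 512, 64

# q_cat stands in for the fused query produced upstream (nope | rope).
q_cat = torch.randn(tokens, heads, kv_lora_rank + rope_dim)

# Split path: slicing is free, but torch.cat materializes a brand-new buffer
# byte-identical to q_cat -- this is the copy the patch skips on decode.
q_nope = q_cat[..., :kv_lora_rank]
q_rope = q_cat[..., kv_lora_rank:]
rebuilt = torch.cat([q_nope, q_rope], dim=-1)
assert rebuilt.data_ptr() != q_cat.data_ptr()  # new allocation + copy
assert torch.equal(rebuilt, q_cat)             # ...with identical contents

# Reuse path: contiguous() on an already-contiguous tensor returns it as-is,
# and view() shares storage, so no copy kernel is launched.
q_all = q_cat.contiguous().view(-1, heads, kv_lora_rank + rope_dim)
assert q_all.data_ptr() == q_cat.data_ptr()    # zero-copy view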