diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index 2fa46c58f788..fc3c67ecab19 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -123,6 +123,27 @@ class TopkTransformMethod(IntEnum):
     RAGGED = auto()
 
 
+@torch.compile
+def _compiled_cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor:
+    return torch.cat(tensors, dim=dim)
+
+
+def _cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor:
+    """
+    Concatenate two tensors along the last dimension.
+    Use this function to concatenate q_nope and q_rope or k_nope and k_rope.
+    """
+    assert len(tensors) == 2
+
+    qk_nope, qk_rope = tensors
+    assert qk_nope.ndim == 3 and qk_rope.ndim == 3
+
+    torch._dynamo.mark_dynamic(qk_nope, 0)
+    torch._dynamo.mark_dynamic(qk_rope, 0)
+
+    return _compiled_cat([qk_nope, qk_rope], dim=dim)
+
+
 @dataclass(frozen=True)
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata
@@ -942,7 +963,7 @@ def forward_extend(
                 kv_cache, page_table_1_flattened
             )
         else:
-            kv_cache = torch.cat([k, k_rope], dim=-1)
+            kv_cache = _cat([k, k_rope], dim=-1)
             page_table_1 = topk_indices
         return self._forward_flashmla_sparse(
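
Usage sketch (not part of the patch) showing the effect of the new helper: marking dim 0 dynamic inside _cat lets torch.compile reuse a single compiled graph for the nope/rope concatenation even as the token count changes between calls, instead of specializing on the first shape and recompiling later. The import path is the patched module; the head dims (512 nope + 64 rope) and the loop of token counts are assumptions for illustration only.

    # Illustrative sketch only -- assumes sglang is installed and the patched
    # module imports cleanly; head dims and token counts are made up.
    import torch
    from sglang.srt.layers.attention.nsa_backend import _cat

    for num_tokens in (7, 128, 513):  # varying first (token) dimension
        k_nope = torch.randn(num_tokens, 1, 512)
        k_rope = torch.randn(num_tokens, 1, 64)
        # torch._dynamo.mark_dynamic(..., 0) inside _cat keeps dim 0 symbolic,
        # so one compiled graph serves every num_tokens value in this loop.
        kv_cache = _cat([k_nope, k_rope], dim=-1)
        assert kv_cache.shape == (num_tokens, 1, 576)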