Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,13 @@ def forward_decode(
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
# Caller passed split q_nope / q_rope; we'll need to concat below if
# the chosen impl wants q_all.
q_all = None
else:
# Caller passed already-concatenated q (q_all = q). Reuse it directly
# as a view (zero-copy when q is already contiguous; `.contiguous()`
# makes a copy otherwise). On HIP, the tilelang / aiter blocks below
# then skip the otherwise redundant concat of q_nope and q_rope.
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
Expand Down Expand Up @@ -1597,7 +1603,11 @@ def forward_decode(
page_table_1=page_table_1,
)
elif self.nsa_decode_impl == "tilelang":
if q_rope is not None:
# Cat-skip (HIP-only): when the caller passes q_rope=None on HIP, q_all
# has already been set to a zero-copy view of q in the else branch
# above and we can reuse it directly. The `not _is_hip` clause keeps
# CUDA / MUSA paths byte-identical to pre-patch by always
# re-concatenating.
if q_all is None or not _is_hip:
q_all = concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_tilelang(
q_all=q_all,
Expand All @@ -1622,7 +1632,7 @@ def forward_decode(
page_size=1,
)
elif self.nsa_decode_impl == "aiter":
if q_rope is not None:
if q_all is None or not _is_hip:
q_all = torch.cat([q_nope, q_rope], dim=-1)
return self._forward_aiter(
q_all=q_all,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,25 +376,55 @@ def forward_absorb_core(
self.rotary_emb.is_neox_style,
q_out_dtype=kv_cache_dtype,
)
q_nope_fused = q_cat[..., : self.kv_lora_rank]
q_pe_fused = q_cat[..., self.kv_lora_rank :]
save_kv_cache = False
if llama_4_scaling is not None:
q_nope_fused *= llama_4_scaling
attn_output = self.attn_mqa(
q_nope_fused,
None,
None,
forward_batch,
q_rope=q_pe_fused,
k_rope=k_pe_fused,
save_kv_cache=save_kv_cache,
**(
dict(topk_indices=topk_indices)
if topk_indices is not None
else {}
),
)
# On decode (or idle), pass q_cat directly to attn_mqa with q_rope=None so
# nsa_backend.forward_decode reuses q_cat as a zero-copy view
# (`q.contiguous().view(...)` fast-path) instead of running the
# redundant `concat_mla_absorb_q_general(q_nope_fused, q_pe_fused)`
# that would otherwise rebuild a tensor byte-identical to q_cat.
# On ROCm tilelang decode, this eliminates the
# `CatArrayBatchedCopy<OpaqueType<1u>, ...>` kernel that used to
# fire once per layer per decode step (~2.6 us / layer saved).
# Prefill keeps the split form because nsa_backend.forward_extend
# asserts `q_rope is not None`.
if forward_batch.forward_mode.is_decode_or_idle():
if llama_4_scaling is not None:
# llama_4_scaling applies only to the q_nope portion;
# mutate in place via the slice view of q_cat.
q_cat[..., : self.kv_lora_rank] *= llama_4_scaling
attn_output = self.attn_mqa(
q_cat,
None,
None,
forward_batch,
q_rope=None,
k_rope=k_pe_fused,
save_kv_cache=save_kv_cache,
**(
dict(topk_indices=topk_indices)
if topk_indices is not None
else {}
),
)
else:
q_nope_fused = q_cat[..., : self.kv_lora_rank]
q_pe_fused = q_cat[..., self.kv_lora_rank :]
if llama_4_scaling is not None:
q_nope_fused *= llama_4_scaling
attn_output = self.attn_mqa(
q_nope_fused,
None,
None,
forward_batch,
q_rope=q_pe_fused,
k_rope=k_pe_fused,
save_kv_cache=save_kv_cache,
**(
dict(topk_indices=topk_indices)
if topk_indices is not None
else {}
),
)
else:
extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch):
Expand Down
Loading