Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
NSA_FUSE_TOPK,
compute_nsa_seqlens,
)
from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip
Expand Down Expand Up @@ -911,7 +912,7 @@ def forward_extend(

if NSA_PREFILL_IMPL == "tilelang":
if q_rope is not None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change replaces torch.cat with _concat_mla_absorb_q_general at several call sites. Before merging, please confirm that _concat_mla_absorb_q_general is functionally equivalent to torch.cat([q_nope, q_rope], dim=-1) for every tensor shape and dtype reaching these paths; if it mishandles any case, the attention output would be silently incorrect. Could you document the scenarios in which the helper is guaranteed to be equivalent, and note any known limitations?

Because this substitution appears in both the prefill and decode paths, any divergence from torch.cat would have widespread impact — treating this as a critical issue.

                q_all = _concat_mla_absorb_q_general(q_nope, q_rope) # Ensure this function is equivalent to torch.cat in all cases

q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
Expand All @@ -921,7 +922,7 @@ def forward_extend(
)
elif NSA_PREFILL_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)

# NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here,
# because the flashmla_sparse kernel doesn't support fp8 compute
Expand All @@ -947,7 +948,7 @@ def forward_extend(
)
elif NSA_PREFILL_IMPL == "flashmla_kv":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_flashmla_kv(
q_all=q_all,
kv_cache=kv_cache,
Expand Down Expand Up @@ -1031,7 +1032,7 @@ def forward_decode(

if NSA_DECODE_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_flashmla_sparse(
q_all=q_all,
kv_cache=kv_cache,
Expand All @@ -1041,7 +1042,7 @@ def forward_decode(
)
elif NSA_DECODE_IMPL == "flashmla_kv":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_flashmla_kv(
q_all=q_all,
kv_cache=kv_cache,
Expand All @@ -1054,7 +1055,7 @@ def forward_decode(
)
elif NSA_DECODE_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
q_all = _concat_mla_absorb_q_general(q_nope, q_rope)
return self._forward_tilelang(
q_all=q_all,
kv_cache=kv_cache,
Expand Down
Loading