-
Notifications
You must be signed in to change notification settings - Fork 5.3k
[DeepseekV32]: use _concat_mla_absorb_q_general to replace torch.cat
#12215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d927518
3859ef8
ebbaa9a
d6e8e1d
bce7513
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| NSA_FUSE_TOPK, | ||
| compute_nsa_seqlens, | ||
| ) | ||
| from sglang.srt.layers.attention.trtllm_mla_backend import _concat_mla_absorb_q_general | ||
| from sglang.srt.layers.dp_attention import get_attention_tp_size | ||
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode | ||
| from sglang.srt.utils import is_hip | ||
|
|
@@ -911,7 +912,7 @@ def forward_extend( | |
|
|
||
| if NSA_PREFILL_IMPL == "tilelang": | ||
| if q_rope is not None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change replaces `torch.cat([q_nope, q_rope], dim=-1)` with `_concat_mla_absorb_q_general(q_nope, q_rope)`. Given the potential for widespread impact, this is a critical issue. q_all = _concat_mla_absorb_q_general(q_nope, q_rope) # Ensure this function is equivalent to torch.cat in all cases |
||
| q_all = torch.cat([q_nope, q_rope], dim=-1) | ||
| q_all = _concat_mla_absorb_q_general(q_nope, q_rope) | ||
| return self._forward_tilelang( | ||
| q_all=q_all, | ||
| kv_cache=kv_cache, | ||
|
|
@@ -921,7 +922,7 @@ def forward_extend( | |
| ) | ||
| elif NSA_PREFILL_IMPL == "flashmla_sparse": | ||
| if q_rope is not None: | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1) | ||
| q_all = _concat_mla_absorb_q_general(q_nope, q_rope) | ||
|
|
||
| # NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here, | ||
| # because the flashmla_sparse kernel doesn't support fp8 compute | ||
|
|
@@ -947,7 +948,7 @@ def forward_extend( | |
| ) | ||
| elif NSA_PREFILL_IMPL == "flashmla_kv": | ||
| if q_rope is not None: | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1) | ||
| q_all = _concat_mla_absorb_q_general(q_nope, q_rope) | ||
| return self._forward_flashmla_kv( | ||
| q_all=q_all, | ||
| kv_cache=kv_cache, | ||
|
|
@@ -1031,7 +1032,7 @@ def forward_decode( | |
|
|
||
| if NSA_DECODE_IMPL == "flashmla_sparse": | ||
| if q_rope is not None: | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1) | ||
| q_all = _concat_mla_absorb_q_general(q_nope, q_rope) | ||
| return self._forward_flashmla_sparse( | ||
| q_all=q_all, | ||
| kv_cache=kv_cache, | ||
|
|
@@ -1041,7 +1042,7 @@ def forward_decode( | |
| ) | ||
| elif NSA_DECODE_IMPL == "flashmla_kv": | ||
| if q_rope is not None: | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1) | ||
| q_all = _concat_mla_absorb_q_general(q_nope, q_rope) | ||
| return self._forward_flashmla_kv( | ||
| q_all=q_all, | ||
| kv_cache=kv_cache, | ||
|
|
@@ -1054,7 +1055,7 @@ def forward_decode( | |
| ) | ||
| elif NSA_DECODE_IMPL == "tilelang": | ||
| if q_rope is not None: | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1) | ||
| q_all = _concat_mla_absorb_q_general(q_nope, q_rope) | ||
| return self._forward_tilelang( | ||
| q_all=q_all, | ||
| kv_cache=kv_cache, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.