python/sglang/srt/layers/attention/trtllm_mla_backend.py (19 additions, 3 deletions)
@@ -207,6 +207,7 @@ class TRTLLMMLAPrefillMetadata:
     max_seq_len: int
     cum_seq_lens: torch.Tensor
     seq_lens: torch.Tensor
+    fallback_to_flashinfer_impl: bool = False
 
 
 @dataclass
@@ -551,7 +552,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
-            if self.disable_chunked_prefix_cache:
+            # For an extend batch with prefix length > 0, fall back to the ragged kernel
+            # implemented in the flashinfer MLA backend when chunked prefix cache is disabled.
+            has_prefix = any(forward_batch.extend_prefix_lens_cpu)
+            fallback_to_flashinfer_impl = (
+                self.disable_chunked_prefix_cache and has_prefix
+            )
+            if fallback_to_flashinfer_impl:
                 super().init_forward_metadata(forward_batch)
 
             seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
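The fallback is taken only when chunked prefix cache is disabled and the batch actually carries a cached prefix; an extend batch with no prefix keeps using the TRT-LLM path. A minimal standalone illustration of that decision, with hypothetical values rather than real backend state:

```python
# Hypothetical per-request prefix lengths for an extend batch.
extend_prefix_lens_cpu = [0, 128, 0]
disable_chunked_prefix_cache = True

has_prefix = any(extend_prefix_lens_cpu)  # True: one request reuses a 128-token prefix
fallback_to_flashinfer_impl = disable_chunked_prefix_cache and has_prefix
print(fallback_to_flashinfer_impl)  # True -> metadata records the flashinfer fallback
```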
@@ -566,6 +573,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 max_seq_len,
                 cum_seq_lens_q,
                 seq_lens,
+                fallback_to_flashinfer_impl,
             )
         elif (
             forward_batch.forward_mode.is_decode_or_idle()
@@ -897,6 +905,15 @@ def forward_extend(
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
+
+        if (
+            self.forward_prefill_metadata is not None
+            and self.forward_prefill_metadata.fallback_to_flashinfer_impl
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
         # TODO refactor to avoid code duplication
         merge_query = q_rope is not None
         if (
@@ -1021,9 +1038,8 @@ def forward_extend(
         if k_rope is not None:
             k = torch.cat([k, k_rope], dim=-1)
         k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-
         v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-
+        # When chunked prefix cache is enabled, dispatch to a different path for ragged attention.
         if forward_batch.attn_attend_prefix_cache:
             # MHA for chunked prefix kv cache when running model with MLA
             assert forward_batch.prefix_chunk_idx is not None
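Taken together, the three pieces (the new metadata flag, the decision in init_forward_metadata, and the early return in forward_extend) route prefix-bearing extend batches to the parent flashinfer path whenever chunked prefix cache is off. A simplified sketch of that control flow, with class and field names abbreviated and the real tensor arguments elided:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PrefillMetadata:
    # Mirrors only the new field on TRTLLMMLAPrefillMetadata; the real dataclass
    # also carries max_seq_len, cum_seq_lens, and seq_lens.
    fallback_to_flashinfer_impl: bool = False


class FlashInferMLABackend:
    """Stand-in for the parent flashinfer backend; real signatures differ."""

    def init_forward_metadata(self, batch) -> None:
        print("flashinfer: building ragged-attention metadata")

    def forward_extend(self, batch) -> str:
        return "output from flashinfer ragged kernel"


class TRTLLMMLABackend(FlashInferMLABackend):
    def __init__(self, disable_chunked_prefix_cache: bool):
        self.disable_chunked_prefix_cache = disable_chunked_prefix_cache
        self.forward_prefill_metadata: Optional[PrefillMetadata] = None

    def init_forward_metadata(self, batch) -> None:
        # Fall back only when chunked prefix cache is off AND the batch has a prefix.
        fallback = self.disable_chunked_prefix_cache and any(
            batch["extend_prefix_lens_cpu"]
        )
        if fallback:
            super().init_forward_metadata(batch)
        self.forward_prefill_metadata = PrefillMetadata(fallback)

    def forward_extend(self, batch) -> str:
        meta = self.forward_prefill_metadata
        if meta is not None and meta.fallback_to_flashinfer_impl:
            # Delegate the whole extend pass to the parent implementation.
            return super().forward_extend(batch)
        return "output from trtllm-gen MLA kernel"


backend = TRTLLMMLABackend(disable_chunked_prefix_cache=True)
batch = {"extend_prefix_lens_cpu": [0, 128]}  # one request reuses a 128-token prefix
backend.init_forward_metadata(batch)
print(backend.forward_extend(batch))  # -> output from flashinfer ragged kernel
```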