diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 65ae9c4dc6eb..7e28341dc373 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -207,6 +207,7 @@ class TRTLLMMLAPrefillMetadata:
     max_seq_len: int
     cum_seq_lens: torch.Tensor
     seq_lens: torch.Tensor
+    fallback_to_flashinfer_impl: bool = False
 
 
 @dataclass
@@ -551,7 +552,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             and not forward_batch.forward_mode.is_target_verify()
             and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
-            if self.disable_chunked_prefix_cache:
+            # For extend batches with prefix length > 0, fall back to the ragged kernel implemented in the flashinfer MLA backend
+            # when chunked prefix cache is disabled.
+            has_prefix = any(forward_batch.extend_prefix_lens_cpu)
+            fallback_to_flashinfer_impl = (
+                self.disable_chunked_prefix_cache and has_prefix
+            )
+            if fallback_to_flashinfer_impl:
                 super().init_forward_metadata(forward_batch)
 
             seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
@@ -566,6 +573,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                 max_seq_len,
                 cum_seq_lens_q,
                 seq_lens,
+                fallback_to_flashinfer_impl,
             )
         elif (
             forward_batch.forward_mode.is_decode_or_idle()
@@ -897,6 +905,15 @@ def forward_extend(
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
+
+        if (
+            self.forward_prefill_metadata is not None
+            and self.forward_prefill_metadata.fallback_to_flashinfer_impl
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
         # TODO refactor to avoid code duplication
         merge_query = q_rope is not None
         if (
@@ -1021,9 +1038,8 @@ def forward_extend(
             if k_rope is not None:
                 k = torch.cat([k, k_rope], dim=-1)
             k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-
         v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-
+        # When chunked prefix cache is enabled, dispatch to a different path for ragged attention.
         if forward_batch.attn_attend_prefix_cache:
             # MHA for chunked prefix kv cache when running model with MLA
             assert forward_batch.prefix_chunk_idx is not None