From 09afb56c1c6807b2af8c8fb2a54d9964397d9193 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 00:39:32 +0000 Subject: [PATCH 1/6] Fix trtllm mla backend when chunked prefix cache is disabled --- .../sglang/srt/layers/attention/trtllm_mla_backend.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 1882881e5d7f..1f6647977d8a 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -860,6 +860,13 @@ def forward_extend( cos_sin_cache: Optional[torch.Tensor] = None, is_neox: Optional[bool] = False, ) -> torch.Tensor: + + # When chunked prefix cache is disabled, fallback to normal MLA path + if self.disable_chunked_prefix_cache: + return super().forward_extend( + q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope + ) + # TODO refactor to avoid code duplication merge_query = q_rope is not None if ( @@ -1003,6 +1010,10 @@ def forward_extend( output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) return output + # When chunked prefix cache is enabled, dispatch to different path for ragged attention. + assert ( + not self.disable_chunked_prefix_cache + ), "Chunked prefix cache should be enabled when using ragged attention." if forward_batch.attn_attend_prefix_cache: # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None From 6c06bb5e598717b514fc5fc3186bac374a2e709a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 01:01:26 +0000 Subject: [PATCH 2/6] upd --- .../srt/layers/attention/trtllm_mla_backend.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 1f6647977d8a..644088b0c891 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -284,6 +284,9 @@ def __init__( self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens + # Whether to fallback to flashinfer MLA kernel + self.fallback_to_flashinfer_mla = False + def _calc_padded_blocks(self, max_seq_len: int) -> int: """ Calculate padded block count that satisfies both TRT-LLM and Triton constraints. @@ -516,7 +519,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend(include_v2=True) ): - if self.disable_chunked_prefix_cache: + # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + self.fallback_to_flashinfer_mla = ( + self.disable_chunked_prefix_cache and not extend_no_prefix + ) + if self.fallback_to_flashinfer_mla: super().init_forward_metadata(forward_batch) seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -537,6 +545,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): or forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend(include_v2=True) ): + self.fallback_to_flashinfer_mla = False bs = forward_batch.batch_size # Get maximum sequence length. @@ -583,6 +592,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata else: + self.fallback_to_flashinfer_mla = True return super().init_forward_metadata(forward_batch) def init_mha_chunk_metadata(self, forward_batch: ForwardBatch): @@ -861,8 +871,7 @@ def forward_extend( is_neox: Optional[bool] = False, ) -> torch.Tensor: - # When chunked prefix cache is disabled, fallback to normal MLA path - if self.disable_chunked_prefix_cache: + if self.fallback_to_flashinfer_mla: return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope ) @@ -1011,9 +1020,6 @@ def forward_extend( return output # When chunked prefix cache is enabled, dispatch to different path for ragged attention. - assert ( - not self.disable_chunked_prefix_cache - ), "Chunked prefix cache should be enabled when using ragged attention." if forward_batch.attn_attend_prefix_cache: # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None From 75b5c4a1b161ba6f2839021b708f28886ca9a611 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 01:34:25 +0000 Subject: [PATCH 3/6] fix --- .../srt/layers/attention/trtllm_mla_backend.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 644088b0c891..9fd802941248 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -207,6 +207,7 @@ class TRTLLMMLAPrefillMetadata: max_seq_len: int cum_seq_lens: torch.Tensor seq_lens: torch.Tensor + fallback_to_flashinfer_mla: bool = False @dataclass @@ -284,9 +285,6 @@ def __init__( self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens - # Whether to fallback to flashinfer MLA kernel - self.fallback_to_flashinfer_mla = False - def _calc_padded_blocks(self, max_seq_len: int) -> int: """ Calculate padded block count that satisfies both TRT-LLM and Triton constraints. @@ -521,10 +519,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ): # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) - self.fallback_to_flashinfer_mla = ( + fallback_to_flashinfer_mla = ( self.disable_chunked_prefix_cache and not extend_no_prefix ) - if self.fallback_to_flashinfer_mla: + if fallback_to_flashinfer_mla: super().init_forward_metadata(forward_batch) seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -539,13 +537,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_seq_len, cum_seq_lens_q, seq_lens, + fallback_to_flashinfer_mla, ) elif ( forward_batch.forward_mode.is_decode_or_idle() or forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend(include_v2=True) ): - self.fallback_to_flashinfer_mla = False bs = forward_batch.batch_size # Get maximum sequence length. @@ -592,7 +590,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata else: - self.fallback_to_flashinfer_mla = True return super().init_forward_metadata(forward_batch) def init_mha_chunk_metadata(self, forward_batch: ForwardBatch): @@ -871,7 +868,7 @@ def forward_extend( is_neox: Optional[bool] = False, ) -> torch.Tensor: - if self.fallback_to_flashinfer_mla: + if self.forward_prefill_metadata.fallback_to_flashinfer_mla: return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope ) From 8024c1e12568eb7b9b22285e495afca5a4f1e34e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 01:36:20 +0000 Subject: [PATCH 4/6] fix --- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 9fd802941248..7d5306e1c38f 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -518,9 +518,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): and not forward_batch.forward_mode.is_draft_extend(include_v2=True) ): # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. - extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + has_prefix = any(forward_batch.extend_prefix_lens_cpu) fallback_to_flashinfer_mla = ( - self.disable_chunked_prefix_cache and not extend_no_prefix + self.disable_chunked_prefix_cache and has_prefix ) if fallback_to_flashinfer_mla: super().init_forward_metadata(forward_batch) From 4b74679ba6bbb6ac7c1015f3ac9c195aa1fd6324 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 30 Oct 2025 02:14:33 +0000 Subject: [PATCH 5/6] fix --- python/sglang/srt/layers/attention/trtllm_mla_backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7d5306e1c38f..4f80ab46efcc 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -868,7 +868,10 @@ def forward_extend( is_neox: Optional[bool] = False, ) -> torch.Tensor: - if self.forward_prefill_metadata.fallback_to_flashinfer_mla: + if ( + self.forward_prefill_metadata is not None + and self.forward_prefill_metadata.fallback_to_flashinfer_mla + ): return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope ) From 16c5640d92793bcdfcd26c597b30dd6e63632055 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 5 Nov 2025 13:16:49 -0800 Subject: [PATCH 6/6] upd --- .../srt/layers/attention/trtllm_mla_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7dcf98adef79..6fdbcefd61b6 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -207,7 +207,7 @@ class TRTLLMMLAPrefillMetadata: max_seq_len: int cum_seq_lens: torch.Tensor seq_lens: torch.Tensor - fallback_to_flashinfer_mla: bool = False + fallback_to_flashinfer_impl: bool = False @dataclass @@ -552,12 +552,13 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_draft_extend(include_v2=True) ): - # For extend batch with prefix length > 0, fallback to flashinfer MLA kernel when chunked prefix cache is disabled. + # For extend batch with prefix length > 0, fallback to ragged kernel implemented in flashinfer MLA backend + # when chunked prefix cache is disabled. has_prefix = any(forward_batch.extend_prefix_lens_cpu) - fallback_to_flashinfer_mla = ( + fallback_to_flashinfer_impl = ( self.disable_chunked_prefix_cache and has_prefix ) - if fallback_to_flashinfer_mla: + if fallback_to_flashinfer_impl: super().init_forward_metadata(forward_batch) seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens @@ -572,7 +573,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_seq_len, cum_seq_lens_q, seq_lens, - fallback_to_flashinfer_mla, + fallback_to_flashinfer_impl, ) elif ( forward_batch.forward_mode.is_decode_or_idle() @@ -907,7 +908,7 @@ def forward_extend( if ( self.forward_prefill_metadata is not None - and self.forward_prefill_metadata.fallback_to_flashinfer_mla + and self.forward_prefill_metadata.fallback_to_flashinfer_impl ): return super().forward_extend( q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope