From 01fe9a35d73802e0c4ef5451d7caf2cfa332b004 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Fri, 15 May 2026 11:30:01 -0700 Subject: [PATCH 1/2] Enable FlashInfer metadata support for Qwen2.5-VL vision attention Signed-off-by: Hua Huang --- vllm/model_executor/models/qwen2_5_vl.py | 85 ++++++++++++++++++++---- 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 54334c91bfa6..2c72aebd14c3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -112,6 +112,7 @@ maybe_prefix, ) from .vision import ( + get_fp8_padded_hidden_size, get_vit_attn_backend, is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model, @@ -368,7 +369,8 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention - sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend + # Only used for FlashInfer CuDNN backend. + sequence_lengths: torch.Tensor | None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -425,6 +427,7 @@ def forward( dynamic_arg_dims={ "x": 0, "cu_seqlens": 0, + "sequence_lengths": 0, "rotary_pos_emb_cos": 0, "rotary_pos_emb_sin": 0, }, @@ -470,6 +473,8 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention + # Only used for FlashInfer CuDNN backend. + sequence_lengths: torch.Tensor | None = None, ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -477,7 +482,7 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - sequence_lengths=None, + sequence_lengths=sequence_lengths, ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) @@ -597,6 +602,12 @@ def __init__( self.spatial_merge_size = vision_config.spatial_merge_size self.fullatt_block_indexes = vision_config.fullatt_block_indexes self.spatial_merge_unit = self.spatial_merge_size**2 + use_data_parallel = is_vit_use_data_parallel() + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, @@ -606,6 +617,11 @@ def __init__( norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads + # FP8 attention: Q/K/V become independent contiguous tensors after + # quantization, so FlashInfer cu_seqlens uses uniform stride. + self.fp8_padded_hidden_size = get_fp8_padded_hidden_size( + self.num_heads, head_dim + ) self.rotary_pos_emb = get_rope( head_size=head_dim, max_position=8192, @@ -899,21 +915,55 @@ def prepare_encoder_metadata( ) ) - # transformers - # pre-compute seqlens for window/full attn to reduce cuMemcpy operations + cu_seqlens_np = cu_seqlens.numpy() + cu_window_seqlens_np = cu_window_seqlens.numpy() + + # FlashInfer needs the real per-sequence lengths in addition to + # cu_seqlens. For other backends this returns None and is ignored. + sequence_lengths_full = MMEncoderAttention.maybe_compute_seq_lens( + self.attn_backend, cu_seqlens_np, device + ) + sequence_lengths_window = MMEncoderAttention.maybe_compute_seq_lens( + self.attn_backend, cu_window_seqlens_np, device + ) + + # Pre-compute max sequence lengths for window/full attention. FlashInfer + # buckets this value for cuDNN graph reuse; other backends keep the exact + # maximum. Keep the scalar on CPU because attention wrappers call .item(). if max_seqlen_override is None: - max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen_full_val = MMEncoderAttention.compute_max_seqlen( + self.attn_backend, cu_seqlens_np + ) else: - max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32) + max_seqlen_full_val = max_seqlen_override + max_seqlen_full = torch.tensor(max_seqlen_full_val, dtype=torch.int32) if max_seqlen_window_override is None: - max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens) - else: - max_seqlen_window = torch.tensor( - max_seqlen_window_override, dtype=torch.int32 + max_seqlen_window_val = MMEncoderAttention.compute_max_seqlen( + self.attn_backend, cu_window_seqlens_np ) - - cu_seqlens = cu_seqlens.to(device=device, non_blocking=True) - cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True) + else: + max_seqlen_window_val = max_seqlen_window_override + max_seqlen_window = torch.tensor(max_seqlen_window_val, dtype=torch.int32) + + # FlashInfer uses backend-specific cu_seqlens offsets into the flattened + # Q/K/O and V buffers. Other backends receive the original cumulative + # token offsets unchanged. + cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( + self.attn_backend, + cu_seqlens_np, + self.hidden_size, + self.tp_size, + device, + fp8_padded_hidden_size=self.fp8_padded_hidden_size, + ) + cu_window_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( + self.attn_backend, + cu_window_seqlens_np, + self.hidden_size, + self.tp_size, + device, + fp8_padded_hidden_size=self.fp8_padded_hidden_size, + ) rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True) rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True) window_index = window_index.to(device=device, non_blocking=True) @@ -927,6 +977,8 @@ def prepare_encoder_metadata( metadata["cu_window_seqlens"] = cu_window_seqlens metadata["max_seqlen_full"] = max_seqlen_full metadata["max_seqlen_window"] = max_seqlen_window + metadata["sequence_lengths_full"] = sequence_lengths_full + metadata["sequence_lengths_window"] = sequence_lengths_window return metadata @@ -952,6 +1004,8 @@ def forward( cu_window_seqlens = encoder_metadata["cu_window_seqlens"] max_seqlen_full = encoder_metadata["max_seqlen_full"] max_seqlen_window = encoder_metadata["max_seqlen_window"] + sequence_lengths_full = encoder_metadata.get("sequence_lengths_full") + sequence_lengths_window = encoder_metadata.get("sequence_lengths_window") hidden_states = hidden_states.reshape( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 @@ -965,9 +1019,11 @@ def forward( if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens max_seqlen_now = max_seqlen_full + sequence_lengths_now = sequence_lengths_full else: cu_seqlens_now = cu_window_seqlens max_seqlen_now = max_seqlen_window + sequence_lengths_now = sequence_lengths_window hidden_states = blk( hidden_states, @@ -975,6 +1031,7 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen_now, + sequence_lengths=sequence_lengths_now, ) # For Qwen2.5-VL-3B, float16 will overflow at last block @@ -1588,6 +1645,8 @@ def get_encoder_cudagraph_config(self): "cu_window_seqlens", "max_seqlen_full", "max_seqlen_window", + "sequence_lengths_full", + "sequence_lengths_window", ], out_hidden_size=self.visual.out_hidden_size, ) From 0a5aa5997ebc55b0ab501565a996f174d1ee1270 Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Sun, 17 May 2026 19:01:42 -0700 Subject: [PATCH 2/2] Minor fix following Gemini Code Assist's comment Signed-off-by: Hua Huang --- vllm/model_executor/models/qwen2_5_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 2c72aebd14c3..933be6412bb2 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -915,8 +915,8 @@ def prepare_encoder_metadata( ) ) - cu_seqlens_np = cu_seqlens.numpy() - cu_window_seqlens_np = cu_window_seqlens.numpy() + cu_seqlens_np = cu_seqlens.cpu().numpy() + cu_window_seqlens_np = cu_window_seqlens.cpu().numpy() # FlashInfer needs the real per-sequence lengths in addition to # cu_seqlens. For other backends this returns None and is ignored.