diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a4c75ab2e03b..0dec414304cc 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -113,6 +113,7 @@ maybe_prefix, ) from .vision import ( + get_fp8_padded_hidden_size, get_vit_attn_backend, is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model, @@ -369,7 +370,8 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention - sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend + # Only used for FlashInfer CuDNN backend. + sequence_lengths: torch.Tensor | None, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -426,6 +428,7 @@ def forward( dynamic_arg_dims={ "x": 0, "cu_seqlens": 0, + "sequence_lengths": 0, "rotary_pos_emb_cos": 0, "rotary_pos_emb_sin": 0, }, @@ -471,6 +474,8 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention + # Only used for FlashInfer CuDNN backend. + sequence_lengths: torch.Tensor | None = None, ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -478,7 +483,7 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - sequence_lengths=None, + sequence_lengths=sequence_lengths, ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) @@ -598,6 +603,12 @@ def __init__( self.spatial_merge_size = vision_config.spatial_merge_size self.fullatt_block_indexes = vision_config.fullatt_block_indexes self.spatial_merge_unit = self.spatial_merge_size**2 + use_data_parallel = is_vit_use_data_parallel() + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, @@ -607,6 +618,11 @@ def __init__( norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads + # FP8 attention: Q/K/V become independent contiguous tensors after + # quantization, so FlashInfer cu_seqlens uses uniform stride. + self.fp8_padded_hidden_size = get_fp8_padded_hidden_size( + self.num_heads, head_dim + ) self.rotary_pos_emb = get_rope( head_size=head_dim, max_position=8192, @@ -902,21 +918,55 @@ def prepare_encoder_metadata( ) ) - # transformers - # pre-compute seqlens for window/full attn to reduce cuMemcpy operations + cu_seqlens_np = cu_seqlens.cpu().numpy() + cu_window_seqlens_np = cu_window_seqlens.cpu().numpy() + + # FlashInfer needs the real per-sequence lengths in addition to + # cu_seqlens. For other backends this returns None and is ignored. + sequence_lengths_full = MMEncoderAttention.maybe_compute_seq_lens( + self.attn_backend, cu_seqlens_np, device + ) + sequence_lengths_window = MMEncoderAttention.maybe_compute_seq_lens( + self.attn_backend, cu_window_seqlens_np, device + ) + + # Pre-compute max sequence lengths for window/full attention. FlashInfer + # buckets this value for cuDNN graph reuse; other backends keep the exact + # maximum. Keep the scalar on CPU because attention wrappers call .item(). if max_seqlen_override is None: - max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen_full_val = MMEncoderAttention.compute_max_seqlen( + self.attn_backend, cu_seqlens_np + ) else: - max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32) + max_seqlen_full_val = max_seqlen_override + max_seqlen_full = torch.tensor(max_seqlen_full_val, dtype=torch.int32) if max_seqlen_window_override is None: - max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens) - else: - max_seqlen_window = torch.tensor( - max_seqlen_window_override, dtype=torch.int32 + max_seqlen_window_val = MMEncoderAttention.compute_max_seqlen( + self.attn_backend, cu_window_seqlens_np ) - - cu_seqlens = cu_seqlens.to(device=device, non_blocking=True) - cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True) + else: + max_seqlen_window_val = max_seqlen_window_override + max_seqlen_window = torch.tensor(max_seqlen_window_val, dtype=torch.int32) + + # FlashInfer uses backend-specific cu_seqlens offsets into the flattened + # Q/K/O and V buffers. Other backends receive the original cumulative + # token offsets unchanged. + cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( + self.attn_backend, + cu_seqlens_np, + self.hidden_size, + self.tp_size, + device, + fp8_padded_hidden_size=self.fp8_padded_hidden_size, + ) + cu_window_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens( + self.attn_backend, + cu_window_seqlens_np, + self.hidden_size, + self.tp_size, + device, + fp8_padded_hidden_size=self.fp8_padded_hidden_size, + ) rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True) rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True) window_index = window_index.to(device=device, non_blocking=True) @@ -930,6 +980,8 @@ def prepare_encoder_metadata( metadata["cu_window_seqlens"] = cu_window_seqlens metadata["max_seqlen_full"] = max_seqlen_full metadata["max_seqlen_window"] = max_seqlen_window + metadata["sequence_lengths_full"] = sequence_lengths_full + metadata["sequence_lengths_window"] = sequence_lengths_window return metadata @@ -955,6 +1007,8 @@ def forward( cu_window_seqlens = encoder_metadata["cu_window_seqlens"] max_seqlen_full = encoder_metadata["max_seqlen_full"] max_seqlen_window = encoder_metadata["max_seqlen_window"] + sequence_lengths_full = encoder_metadata.get("sequence_lengths_full") + sequence_lengths_window = encoder_metadata.get("sequence_lengths_window") hidden_states = hidden_states.reshape( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 @@ -968,9 +1022,11 @@ def forward( if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens max_seqlen_now = max_seqlen_full + sequence_lengths_now = sequence_lengths_full else: cu_seqlens_now = cu_window_seqlens max_seqlen_now = max_seqlen_window + sequence_lengths_now = sequence_lengths_window hidden_states = blk( hidden_states, @@ -978,6 +1034,7 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen_now, + sequence_lengths=sequence_lengths_now, ) # For Qwen2.5-VL-3B, float16 will overflow at last block @@ -1594,6 +1651,8 @@ def get_encoder_cudagraph_config(self): "cu_window_seqlens", "max_seqlen_full", "max_seqlen_window", + "sequence_lengths_full", + "sequence_lengths_window", ], out_hidden_size=self.visual.out_hidden_size, max_frames_per_video=max_frames,