Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes (72 additions, 13 deletions): vllm/model_executor/models/qwen2_5_vl.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -113,6 +113,7 @@
maybe_prefix,
)
from .vision import (
get_fp8_padded_hidden_size,
get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
Expand Down Expand Up @@ -369,7 +370,8 @@ def forward(
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
# Only used for FlashInfer CuDNN backend.
sequence_lengths: torch.Tensor | None,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
Expand Down Expand Up @@ -426,6 +428,7 @@ def forward(
dynamic_arg_dims={
"x": 0,
"cu_seqlens": 0,
"sequence_lengths": 0,
"rotary_pos_emb_cos": 0,
"rotary_pos_emb_sin": 0,
},
Expand Down Expand Up @@ -471,14 +474,16 @@ def forward(
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
# Only used for FlashInfer CuDNN backend.
sequence_lengths: torch.Tensor | None = None,
Comment thread
huanghua1994 marked this conversation as resolved.
) -> torch.Tensor:
x_attn = self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=None,
sequence_lengths=sequence_lengths,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
Expand Down Expand Up @@ -598,6 +603,12 @@ def __init__(
self.spatial_merge_size = vision_config.spatial_merge_size
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
self.spatial_merge_unit = self.spatial_merge_size**2
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
else parallel_state.get_tensor_model_parallel_world_size()
)
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
Expand All @@ -607,6 +618,11 @@ def __init__(

norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
# FP8 attention: Q/K/V become independent contiguous tensors after
# quantization, so FlashInfer cu_seqlens uses uniform stride.
self.fp8_padded_hidden_size = get_fp8_padded_hidden_size(
Comment thread
huanghua1994 marked this conversation as resolved.
self.num_heads, head_dim
)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
max_position=8192,
Expand Down Expand Up @@ -902,21 +918,55 @@ def prepare_encoder_metadata(
)
)

# transformers
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
cu_seqlens_np = cu_seqlens.cpu().numpy()
cu_window_seqlens_np = cu_window_seqlens.cpu().numpy()

# FlashInfer needs the real per-sequence lengths in addition to
# cu_seqlens. For other backends this returns None and is ignored.
sequence_lengths_full = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens_np, device
)
sequence_lengths_window = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_window_seqlens_np, device
)

# Pre-compute max sequence lengths for window/full attention. FlashInfer
# buckets this value for cuDNN graph reuse; other backends keep the exact
# maximum. Keep the scalar on CPU because attention wrappers call .item().
if max_seqlen_override is None:
max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
max_seqlen_full_val = MMEncoderAttention.compute_max_seqlen(
self.attn_backend, cu_seqlens_np
)
else:
max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32)
max_seqlen_full_val = max_seqlen_override
max_seqlen_full = torch.tensor(max_seqlen_full_val, dtype=torch.int32)
if max_seqlen_window_override is None:
max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
else:
max_seqlen_window = torch.tensor(
max_seqlen_window_override, dtype=torch.int32
max_seqlen_window_val = MMEncoderAttention.compute_max_seqlen(
self.attn_backend, cu_window_seqlens_np
)

cu_seqlens = cu_seqlens.to(device=device, non_blocking=True)
cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True)
else:
max_seqlen_window_val = max_seqlen_window_override
max_seqlen_window = torch.tensor(max_seqlen_window_val, dtype=torch.int32)

# FlashInfer uses backend-specific cu_seqlens offsets into the flattened
# Q/K/O and V buffers. Other backends receive the original cumulative
# token offsets unchanged.
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_seqlens_np,
self.hidden_size,
self.tp_size,
device,
fp8_padded_hidden_size=self.fp8_padded_hidden_size,
)
cu_window_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_window_seqlens_np,
self.hidden_size,
self.tp_size,
device,
fp8_padded_hidden_size=self.fp8_padded_hidden_size,
)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True)
window_index = window_index.to(device=device, non_blocking=True)
Expand All @@ -930,6 +980,8 @@ def prepare_encoder_metadata(
metadata["cu_window_seqlens"] = cu_window_seqlens
metadata["max_seqlen_full"] = max_seqlen_full
metadata["max_seqlen_window"] = max_seqlen_window
metadata["sequence_lengths_full"] = sequence_lengths_full
metadata["sequence_lengths_window"] = sequence_lengths_window

return metadata

Expand All @@ -955,6 +1007,8 @@ def forward(
cu_window_seqlens = encoder_metadata["cu_window_seqlens"]
max_seqlen_full = encoder_metadata["max_seqlen_full"]
max_seqlen_window = encoder_metadata["max_seqlen_window"]
sequence_lengths_full = encoder_metadata.get("sequence_lengths_full")
sequence_lengths_window = encoder_metadata.get("sequence_lengths_window")

hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
Expand All @@ -968,16 +1022,19 @@ def forward(
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
max_seqlen_now = max_seqlen_full
sequence_lengths_now = sequence_lengths_full
else:
cu_seqlens_now = cu_window_seqlens
max_seqlen_now = max_seqlen_window
sequence_lengths_now = sequence_lengths_window

hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen_now,
sequence_lengths=sequence_lengths_now,
)

# For Qwen2.5-VL-3B, float16 will overflow at last block
Expand Down Expand Up @@ -1594,6 +1651,8 @@ def get_encoder_cudagraph_config(self):
"cu_window_seqlens",
"max_seqlen_full",
"max_seqlen_window",
"sequence_lengths_full",
"sequence_lengths_window",
],
out_hidden_size=self.visual.out_hidden_size,
max_frames_per_video=max_frames,
Expand Down
Loading