From c4b221c631f40040e44d6326d0f94e277e61a629 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 18 Nov 2025 15:50:51 +0000 Subject: [PATCH 1/2] [Model][QwenVL] Simplify cos/sin rotary embedding indexing Signed-off-by: Lukas Geiger --- vllm/model_executor/models/glm4_1v.py | 9 ++------- vllm/model_executor/models/qwen2_5_vl.py | 9 ++------- vllm/model_executor/models/qwen2_vl.py | 9 ++------- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 9 ++------- vllm/model_executor/models/qwen3_vl.py | 9 ++------- 5 files changed, 10 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 65c3fc2d9e97..ed409b42fb69 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -797,13 +797,8 @@ def rot_pos_emb( # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) - cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) - cos_w = cos[pos_ids[:, 1]] - sin_h = sin[pos_ids[:, 0]] - sin_w = sin[pos_ids[:, 1]] - - cos_combined = torch.cat([cos_h, cos_w], dim=-1) - sin_combined = torch.cat([sin_h, sin_w], dim=-1) + cos_combined = cos[pos_ids].flatten(1) + sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined, pos_ids def compute_attn_mask_seqlen( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 2e4fd9645d88..5b5d50ec8935 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -738,13 +738,8 @@ def rotary_pos_emb_thw(self, t, h, w): # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_size) - cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) - cos_w = cos[pos_ids[:, 1]] - sin_h = sin[pos_ids[:, 0]] - sin_w = sin[pos_ids[:, 1]] - - cos_combined = torch.cat([cos_h, cos_w], dim=-1) - sin_combined = torch.cat([sin_h, sin_w], dim=-1) + cos_combined = cos[pos_ids].flatten(1) + sin_combined = sin[pos_ids].flatten(1) cos_combined = cos_combined.reshape( cos_combined.shape[0] // self.spatial_merge_unit, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 53df5972a8fe..cda8eaf5377f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -724,13 +724,8 @@ def rot_pos_emb( # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) - cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) - cos_w = cos[pos_ids[:, 1]] - sin_h = sin[pos_ids[:, 0]] - sin_w = sin[pos_ids[:, 1]] - - cos_combined = torch.cat([cos_h, cos_w], dim=-1) - sin_combined = torch.cat([sin_h, sin_w], dim=-1) + cos_combined = cos[pos_ids].flatten(1) + sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined def compute_attn_mask_seqlen( diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 8274b92138f7..d2fd74a5e41a 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -428,13 +428,8 @@ def rot_pos_emb(self, grid_thw): # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) - cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) - cos_w = cos[pos_ids[:, 1]] - sin_h = sin[pos_ids[:, 0]] - sin_w = sin[pos_ids[:, 1]] - - cos_combined = torch.cat([cos_h, cos_w], dim=-1) - sin_combined = torch.cat([sin_h, sin_w], dim=-1) + cos_combined = cos[pos_ids].flatten(1) + sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 99a4007ef7f2..1669cf78a2c2 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -464,13 +464,8 @@ def rot_pos_emb(self, grid_thw: list[list[int]]): # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) - cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) - cos_w = cos[pos_ids[:, 1]] - sin_h = sin[pos_ids[:, 0]] - sin_w = sin[pos_ids[:, 1]] - - cos_combined = torch.cat([cos_h, cos_w], dim=-1) - sin_combined = torch.cat([sin_h, sin_w], dim=-1) + cos_combined = cos[pos_ids].flatten(1) + sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined From cc1f0c5e07fd361852fec7cad272c8b746c36d61 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 18 Nov 2025 19:29:29 +0000 Subject: [PATCH 2/2] [Model][Qwen3VL] Prevent synchronous CPU-GPU copy Signed-off-by: Lukas Geiger --- vllm/model_executor/models/qwen3_vl.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1669cf78a2c2..0c546309400b 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -459,7 +459,7 @@ def rot_pos_emb(self, grid_thw: list[list[int]]): else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1) for t, h, w in grid_thw ] - pos_ids = torch.cat(pos_ids, dim=0) + pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True) # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) @@ -561,12 +561,6 @@ def forward( pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) hidden_states = hidden_states + pos_embeds rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) - rotary_pos_emb_cos = rotary_pos_emb_cos.to( - hidden_states.device, non_blocking=True - ) - rotary_pos_emb_sin = rotary_pos_emb_sin.to( - hidden_states.device, non_blocking=True - ) cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]