Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions vllm/model_executor/models/glm4_1v.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,13 +797,8 @@ def rot_pos_emb(
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]

cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined, pos_ids

def compute_attn_mask_seqlen(
Expand Down
9 changes: 2 additions & 7 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,13 +738,8 @@ def rotary_pos_emb_thw(self, t, h, w):
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)

cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]

cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)

cos_combined = cos_combined.reshape(
cos_combined.shape[0] // self.spatial_merge_unit,
Expand Down
9 changes: 2 additions & 7 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,13 +724,8 @@ def rot_pos_emb(
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]

cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined

def compute_attn_mask_seqlen(
Expand Down
9 changes: 2 additions & 7 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,8 @@ def rot_pos_emb(self, grid_thw):
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]

cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)

return cos_combined, sin_combined

Expand Down
17 changes: 3 additions & 14 deletions vllm/model_executor/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,18 +459,13 @@ def rot_pos_emb(self, grid_thw: list[list[int]]):
else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
for t, h, w in grid_thw
]
pos_ids = torch.cat(pos_ids, dim=0)
pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)

# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
cos_w = cos[pos_ids[:, 1]]
sin_h = sin[pos_ids[:, 0]]
sin_w = sin[pos_ids[:, 1]]

cos_combined = torch.cat([cos_h, cos_w], dim=-1)
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)

return cos_combined, sin_combined

Expand Down Expand Up @@ -566,12 +561,6 @@ def forward(
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
rotary_pos_emb_cos = rotary_pos_emb_cos.to(
hidden_states.device, non_blocking=True
)
rotary_pos_emb_sin = rotary_pos_emb_sin.to(
hidden_states.device, non_blocking=True
)

cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
Expand Down