Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions vllm_gaudi/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,25 @@ def forward(cls, q, k, v, mask, cu_seqlens, qwen2_5_vl, q_block_size=64):
Support long sequence at prompt phase
"""
q_len = q.size(-2)
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if mask is not None or len(lens) == 1:
if not qwen2_5_vl or (qwen2_5_vl and q_len < 65536):
if qwen2_5_vl:
if q_len < 65536:
return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode)
else:
return AttentionLongSequence.forward(q, k, v, mask, q_block_size, cls.softmax_mode)
else:
q_chunks = torch.split(q, lens, dim=2)
k_chunks = torch.split(k, lens, dim=2)
v_chunks = torch.split(v, lens, dim=2)
outputs = []
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode)
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=2)
return context_layer
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if len(lens) == 1:
return FusedSDPA.apply(q, k, v, None, 0.0, False, None, cls.softmax_mode)
else:
q_chunks = torch.split(q, lens, dim=2)
k_chunks = torch.split(k, lens, dim=2)
v_chunks = torch.split(v, lens, dim=2)
outputs = []
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode)
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=2)
return context_layer


def create_block_diagonal_attention_mask(indices):
Expand Down Expand Up @@ -237,13 +240,11 @@ def forward(
seqlens: Optional[list[int]] = None, # Only used for xFormers
attn_mask: Optional[torch.Tensor] = None, # Only used for HPU
) -> torch.Tensor:
mask_to_use = attn_mask if attn_mask is not None else cu_seqlens

x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
attn_mask=mask_to_use)
attn_mask=attn_mask)

x = x + self.mlp(self.norm2(x))
return x
Expand Down Expand Up @@ -363,7 +364,8 @@ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor):
)

def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor) -> torch.Tensor:
padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor,
cu_seqlens: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.unsqueeze(1)
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
Expand All @@ -373,9 +375,10 @@ def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor,

hidden_states = blk(
hidden_states,
cu_seqlens=padding_attn_mask_now,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
attn_mask=padding_attn_mask_now,
)

# For Qwen2.5-VL-3B, float16 will overflow at last block
Expand Down Expand Up @@ -443,7 +446,8 @@ def get_image_embeds(
rotary_pos_emb_cos=rot_pos_emb_cos,
rotary_pos_emb_sin=rot_pos_emb_sin,
padding_attn_mask_window=padding_attn_mask_window,
padding_attn_mask_full=padding_attn_mask_full)
padding_attn_mask_full=padding_attn_mask_full,
cu_seqlens=cu_seqlens)
htcore.mark_step()

# remove padding
Expand Down