diff --git a/vllm_gaudi/models/qwen2_5_vl.py b/vllm_gaudi/models/qwen2_5_vl.py index 98f699fd76..d8b7fb97f9 100644 --- a/vllm_gaudi/models/qwen2_5_vl.py +++ b/vllm_gaudi/models/qwen2_5_vl.py @@ -77,22 +77,25 @@ def forward(cls, q, k, v, mask, cu_seqlens, qwen2_5_vl, q_block_size=64): Support long sequence at prompt phase """ q_len = q.size(-2) - lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - if mask is not None or len(lens) == 1: - if not qwen2_5_vl or (qwen2_5_vl and q_len < 65536): + if qwen2_5_vl: + if q_len < 65536: return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode) else: return AttentionLongSequence.forward(q, k, v, mask, q_block_size, cls.softmax_mode) else: - q_chunks = torch.split(q, lens, dim=2) - k_chunks = torch.split(k, lens, dim=2) - v_chunks = torch.split(v, lens, dim=2) - outputs = [] - for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): - output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode) - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=2) - return context_layer + lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + if len(lens) == 1: + return FusedSDPA.apply(q, k, v, None, 0.0, False, None, cls.softmax_mode) + else: + q_chunks = torch.split(q, lens, dim=2) + k_chunks = torch.split(k, lens, dim=2) + v_chunks = torch.split(v, lens, dim=2) + outputs = [] + for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): + output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode) + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=2) + return context_layer def create_block_diagonal_attention_mask(indices): @@ -237,13 +240,11 @@ def forward( seqlens: Optional[list[int]] = None, # Only used for xFormers attn_mask: Optional[torch.Tensor] = None, # Only used for HPU ) -> torch.Tensor: - mask_to_use = attn_mask if attn_mask is not None else cu_seqlens - x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, - attn_mask=mask_to_use) + attn_mask=attn_mask) x = x + self.mlp(self.norm2(x)) return x @@ -363,7 +364,8 @@ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor): ) def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, - padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor) -> torch.Tensor: + padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor, + cu_seqlens: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.unsqueeze(1) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -373,9 +375,10 @@ def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, hidden_states = blk( hidden_states, - cu_seqlens=padding_attn_mask_now, + cu_seqlens=cu_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, + attn_mask=padding_attn_mask_now, ) # For Qwen2.5-VL-3B, float16 will overflow at last block @@ -443,7 +446,8 @@ def get_image_embeds( rotary_pos_emb_cos=rot_pos_emb_cos, rotary_pos_emb_sin=rot_pos_emb_sin, padding_attn_mask_window=padding_attn_mask_window, - padding_attn_mask_full=padding_attn_mask_full) + padding_attn_mask_full=padding_attn_mask_full, + cu_seqlens=cu_seqlens) htcore.mark_step() # remove padding