diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py
index fcd9f5f0d1a2..84b9dc53a1d7 100644
--- a/src/transformers/models/glm4v/modeling_glm4v.py
+++ b/src/transformers/models/glm4v/modeling_glm4v.py
@@ -296,7 +296,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
@@ -320,27 +319,51 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask=attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
-            cu_seq_lens_k=cu_seqlens,
-            max_length_q=max_seqlen,
-            max_length_k=max_seqlen,
-            is_causal=False,
-            **kwargs,
-        )
+        if self.config._attn_implementation == "flash_attention_2":
+            # Flash Attention 2: Use cu_seqlens for variable length attention
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=None,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                cu_seq_lens_q=cu_seqlens,
+                cu_seq_lens_k=cu_seqlens,
+                max_length_q=max_seqlen,
+                max_length_k=max_seqlen,
+                is_causal=False,
+                **kwargs,
+            )
+        else:
+            # Other implementations: Process each chunk separately
+            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            splits = [
+                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+            ]
+
+            attn_outputs = [
+                attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]
+                for q, k, v in zip(*splits)
+            ]
+            attn_output = torch.cat(attn_outputs, dim=1)
 
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
@@ -361,7 +384,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
@@ -369,7 +391,6 @@ def forward(
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -467,25 +488,6 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -515,14 +517,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
-        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
 
         for blk in self.blocks:
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 position_embeddings=position_embeddings,
-                attention_mask=attention_mask,
             )
 
         hidden_states = self.post_layernorm(hidden_states)
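To see the new non-FA2 fallback in isolation: the sketch below mirrors the `else` branch above, with plain `torch.nn.functional.scaled_dot_product_attention` standing in for the dispatched `attention_interface`. The helper name `per_chunk_attention` and the shapes in the comments are illustrative only, not part of the patch.

```python
import torch
import torch.nn.functional as F


def per_chunk_attention(q, k, v, cu_seqlens, scaling=None):
    # q, k, v: (1, num_heads, seq_len, head_dim), packed along dim=2 exactly like the
    # query/key/value states above after .transpose(0, 1).unsqueeze(0).
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    chunks = [torch.split(t, lengths, dim=2) for t in (q, k, v)]
    outputs = [
        F.scaled_dot_product_attention(qc, kc, vc, scale=scaling, is_causal=False)
        for qc, kc, vc in zip(*chunks)
    ]
    # Stitch the per-image (or per-window) results back into one packed sequence.
    return torch.cat(outputs, dim=2)
```

One difference worth flagging for reviewers: the patch concatenates on `dim=1` because the dispatched functions in `ALL_ATTENTION_FUNCTIONS` return outputs in `(batch, seq, heads, head_dim)` order, while raw SDPA keeps `(batch, heads, seq, head_dim)`, hence `dim=2` in this sketch.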
diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py
index 48cf2a7bb220..1c25215f3447 100644
--- a/src/transformers/models/glm4v/modular_glm4v.py
+++ b/src/transformers/models/glm4v/modular_glm4v.py
@@ -603,25 +603,6 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -651,14 +632,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
-        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
 
         for blk in self.blocks:
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 position_embeddings=position_embeddings,
-                attention_mask=attention_mask,
             )
 
         hidden_states = self.post_layernorm(hidden_states)
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index aab42610461c..d0d957b8eccb 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -956,7 +956,6 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
@@ -969,27 +968,51 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask=attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
-            cu_seq_lens_k=cu_seqlens,
-            max_length_q=max_seqlen,
-            max_length_k=max_seqlen,
-            is_causal=False,
-            **kwargs,
-        )
+        if self.config._attn_implementation == "flash_attention_2":
+            # Flash Attention 2: Use cu_seqlens for variable length attention
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=None,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                cu_seq_lens_q=cu_seqlens,
+                cu_seq_lens_k=cu_seqlens,
+                max_length_q=max_seqlen,
+                max_length_k=max_seqlen,
+                is_causal=False,
+                **kwargs,
+            )
+        else:
+            # Other implementations: Process each chunk separately
+            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            splits = [
+                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+            ]
+
+            attn_outputs = [
+                attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]
+                for q, k, v in zip(*splits)
+            ]
+            attn_output = torch.cat(attn_outputs, dim=1)
 
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
@@ -1023,14 +1046,12 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
-            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -1190,25 +1211,6 @@ def get_window_index(self, grid_thw):
         return window_index, cu_window_seqlens
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
@@ -1256,12 +1258,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
             else:
                 cu_seqlens_now = cu_window_seqlens
 
-            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 rotary_pos_emb=rotary_pos_emb,
-                attention_mask=attention_mask,
                 **kwargs,
             )
         hidden_states = self.merger(hidden_states)
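The same helper removal recurs in every file below. A quick, self-contained way to convince yourself that dropping the block-diagonal 4D mask is numerically safe: build the additive mask exactly the way the deleted `_prepare_attention_mask` did and compare full masked attention against the new per-chunk path. SDPA stands in for the dispatched interface and the tensor sizes are made up for the check; this is a verification sketch, not code from the patch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 4, 8
cu_seqlens = torch.tensor([0, 3, 7, 12])
seq_len = int(cu_seqlens[-1])
q, k, v = (torch.randn(1, num_heads, seq_len, head_dim) for _ in range(3))

# Additive block-diagonal mask, built the same way as the deleted helper.
mask = torch.full((1, 1, seq_len, seq_len), torch.finfo(q.dtype).min)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Per-chunk attention, as in the new else-branch.
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
chunked_out = torch.cat(
    [
        F.scaled_dot_product_attention(qc, kc, vc)
        for qc, kc, vc in zip(*(torch.split(t, lengths, dim=2) for t in (q, k, v)))
    ],
    dim=2,
)
torch.testing.assert_close(masked_out, chunked_out)  # should agree within float32 tolerance
```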
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index b3d4ae90e886..a792cd8a5403 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1934,7 +1934,6 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
@@ -1947,27 +1946,51 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask=attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
-            cu_seq_lens_k=cu_seqlens,
-            max_length_q=max_seqlen,
-            max_length_k=max_seqlen,
-            is_causal=False,
-            **kwargs,
-        )
+        if self.config._attn_implementation == "flash_attention_2":
+            # Flash Attention 2: Use cu_seqlens for variable length attention
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=None,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                cu_seq_lens_q=cu_seqlens,
+                cu_seq_lens_k=cu_seqlens,
+                max_length_q=max_seqlen,
+                max_length_k=max_seqlen,
+                is_causal=False,
+                **kwargs,
+            )
+        else:
+            # Other implementations: Process each chunk separately
+            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            splits = [
+                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+            ]
+
+            attn_outputs = [
+                attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]
+                for q, k, v in zip(*splits)
+            ]
+            attn_output = torch.cat(attn_outputs, dim=1)
 
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
@@ -1984,14 +2007,12 @@ def forward(
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
-            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -2006,25 +2027,6 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) ->
         super().__init__(config, *inputs, **kwargs)
         self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
@@ -2072,12 +2074,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
             else:
                 cu_seqlens_now = cu_window_seqlens
 
-            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 rotary_pos_emb=rotary_pos_emb,
-                attention_mask=attention_mask,
                 **kwargs,
             )
         hidden_states = self.merger(hidden_states)
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 97d7791faafc..0285d18cd2fe 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -215,7 +215,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
@@ -239,27 +238,51 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask=attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
-            cu_seq_lens_k=cu_seqlens,
-            max_length_q=max_seqlen,
-            max_length_k=max_seqlen,
-            is_causal=False,
-            **kwargs,
-        )
+        if self.config._attn_implementation == "flash_attention_2":
+            # Flash Attention 2: Use cu_seqlens for variable length attention
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=None,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                cu_seq_lens_q=cu_seqlens,
+                cu_seq_lens_k=cu_seqlens,
+                max_length_q=max_seqlen,
+                max_length_k=max_seqlen,
+                is_causal=False,
+                **kwargs,
+            )
+        else:
+            # Other implementations: Process each chunk separately
+            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            splits = [
+                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+            ]
+
+            attn_outputs = [
+                attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]
+                for q, k, v in zip(*splits)
+            ]
+            attn_output = torch.cat(attn_outputs, dim=1)
 
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
@@ -280,7 +303,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
@@ -288,7 +310,6 @@ def forward(
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -422,25 +443,6 @@ def get_window_index(self, grid_thw):
         return window_index, cu_window_seqlens
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
@@ -488,12 +490,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
             else:
                 cu_seqlens_now = cu_window_seqlens
 
-            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 position_embeddings=position_embeddings,
-                attention_mask=attention_mask,
                 **kwargs,
             )
 
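For the branch that is kept, the `cu_seq_lens_*` / `max_length_*` kwargs are what FlashAttention's varlen kernel ultimately consumes. Below is a rough, hedged sketch of the equivalent direct call, assuming the `flash-attn` package, CUDA fp16/bf16 tensors flattened to `(total_tokens, num_heads, head_dim)`, and `flash_attn_varlen_func`'s usual keyword names; none of this is code from the patch.

```python
# Sketch only: how packed vision patches map onto FlashAttention's varlen API.
import torch
from flash_attn import flash_attn_varlen_func  # requires flash-attn and a CUDA device


def fa2_packed_attention(q, k, v, cu_seqlens, scaling):
    # q, k, v: (total_tokens, num_heads, head_dim) in fp16/bf16, all images/windows packed together.
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens.to(torch.int32),
        cu_seqlens_k=cu_seqlens.to(torch.int32),
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        softmax_scale=scaling,
        causal=False,  # patches attend bidirectionally within each cu_seqlens block
    )
```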
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index f18e0b3461aa..07f41356e800 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -159,7 +159,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
@@ -167,7 +166,6 @@ def forward(
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -289,25 +287,6 @@ def get_window_index(self, grid_thw):
         return window_index, cu_window_seqlens
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         """
         Args:
@@ -355,12 +334,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
             else:
                 cu_seqlens_now = cu_window_seqlens
 
-            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 position_embeddings=position_embeddings,
-                attention_mask=attention_mask,
                 **kwargs,
             )
 
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 2cd1a61b80fb..a8b2ebf1a9fe 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -333,7 +333,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
@@ -357,27 +356,51 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask=attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
-            cu_seq_lens_k=cu_seqlens,
-            max_length_q=max_seqlen,
-            max_length_k=max_seqlen,
-            is_causal=False,
-            **kwargs,
-        )
+        if self.config._attn_implementation == "flash_attention_2":
+            # Flash Attention 2: Use cu_seqlens for variable length attention
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=None,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                cu_seq_lens_q=cu_seqlens,
+                cu_seq_lens_k=cu_seqlens,
+                max_length_q=max_seqlen,
+                max_length_k=max_seqlen,
+                is_causal=False,
+                **kwargs,
+            )
+        else:
+            # Other implementations: Process each chunk separately
+            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+            splits = [
+                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+            ]
+
+            attn_outputs = [
+                attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]
+                for q, k, v in zip(*splits)
+            ]
+            attn_output = torch.cat(attn_outputs, dim=1)
 
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
@@ -400,7 +423,6 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
@@ -408,7 +430,6 @@ def forward(
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
            **kwargs,
        )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -721,25 +742,6 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
-    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
-        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
-        # NOTE: the created attention masl only approximates the ragged FA2 attention by
-        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
-        # blocks. Though it will not be a 100% match for FA2's `varlen` path
-        if self.config._attn_implementation == "flash_attention_2":
-            return None
-
-        seq_length = inputs_tensor.shape[0]
-        attention_mask = torch.full(
-            [1, 1, seq_length, seq_length],
-            torch.finfo(inputs_tensor.dtype).min,
-            device=inputs_tensor.device,
-            dtype=inputs_tensor.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
-        return attention_mask
-
     @auto_docstring
     def forward(
         self,
@@ -765,14 +767,12 @@ def forward(
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
 
         for blk in self.blocks:
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 position_embeddings=position_embeddings,
-                attention_mask=attention_mask,
                 **kwargs,
             )
 
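The `cu_seqlens` driving both branches are derived from `grid_thw` right before the block loop (the `cumsum(..., dtype=torch.int32)` plus `F.pad` visible in the Qwen2-VL hunk above). As I read the surrounding code, the computation is the standard repeat_interleave/cumsum form; the two-image `grid_thw` below is made up for illustration.

```python
import torch
import torch.nn.functional as F

# One row per image/video: (temporal, height, width) in patch units. Values are illustrative.
grid_thw = torch.tensor([[1, 4, 6], [2, 3, 4]])

# Tokens per frame = h * w, repeated t times per sample, then accumulated into chunk boundaries.
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 36, 48], dtype=torch.int32)
```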
diff --git a/tests/models/glm4v/test_modeling_glm4v.py b/tests/models/glm4v/test_modeling_glm4v.py
index 39b66875c2d4..e5211e59f2ed 100644
--- a/tests/models/glm4v/test_modeling_glm4v.py
+++ b/tests/models/glm4v/test_modeling_glm4v.py
@@ -419,7 +419,7 @@ def test_small_model_integration_test_batch_different_resolutions(self):
         output = model.generate(**inputs, max_new_tokens=30)
 
         EXPECTED_DECODED_TEXT = [
-            "\nWhat kind of dog is this?\nGot it, let's look at the image. The animal in the picture has a stocky build, thick fur, and a face that's",
+            "\nWhat kind of dog is this?\nGot it, let's look at the image. The animal in the picture is not a dog; it's a cat. Specifically, it looks",
             "\nWhat kind of dog is this?\nGot it, let's look at the image. Wait, the animals here are cats, not dogs. The question is about a dog, but"
         ]  # fmt: skip
         self.assertEqual(