25 changes: 17 additions & 8 deletions src/transformers/modeling_flash_attention_utils.py
@@ -508,6 +508,22 @@ def _flash_attention_forward(
query_states, key_states, value_states, target_dtype
)

# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
# under two cases:
# Case 1. `position_ids` is provided and not all examples contain only one sequence: if the tensor is increasing,
# we probably have a single sequence; otherwise it is packed. Additionally, check that we are in the pre-fill/training stage.
# Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe to
# use `flash_attn_varlen_func` knowing we already have all the necessary kwargs. NOTE: it is the user's responsibility
# to take care of flattening `position_ids` if the model needs it. See #39121 for more information.
is_fa2_with_position_ids = (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
)
is_fa2_with_varlen_kwargs = all(
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)
Comment on lines +518 to +525
Contributor:
I'm a bit unsure about this since we would allow only cu_seq and max_seq, but most models also have RoPE, so it would silently break those models if we mess up and don't pass position_ids (RoPE positions are bound to position_ids as well). Imo we should at least add a warning when only the varlen kwargs are passed, to give some discretion here.

On another note, what do the integration tests use? Are they still working as expected 👀 seems a bit sus

Member Author:

The integration tests always use position_ids, which is the first case (is_fa2_with_position_ids); I just copied it from the existing code and moved it up here. The second case is added for Qwen only; afaik no other model passes pre-computed cu_lens from its attention layers.

In Qwen we don't need any position ids, because they are 3D and won't help at all in inferring cu_lens.

Contributor (@vasqu, Jul 1, 2025):

I meant the general integration tests, e.g.

def test_small_model_integration_test_batch_flashatt2(self):

As for why: what concerns me about the second case is future model additions and general usability, not the validity of Qwen itself. For developers, is_fa2_with_varlen_kwargs signals that this alone suffices for varlen - before, we (unintentionally) also checked for the existence of the (correctly flattened) position ids that RoPE models need when using varlen. Maybe #35941 helps as a reference for what I mean.

Imo, it would help to at least add comments that for varlen most models need correctly flattened position ids (from e.g. a collator), especially RoPE models, which make up the majority of newer models.

Member Author:

Hmm, am I right that users might be passing only cu_lens without correct position_ids? I believe it would be the users' responsibility to make sure RoPE is applied correctly, but I will add a comment in the code explaining it, sure.

In the slow integration tests we don't pass position_ids, not that I know of. For most LLMs the fa2-path integration tests fall back to inferring cu_lens from the mask, and in Qwen the position_ids are constructed on the fly during the forward call. The model requires adding rope deltas on top of the 3D positions, and I don't think users would be doing all that manually.

Contributor (@vasqu, Jul 1, 2025):

"Hmm, am I right that users might be passing only cu_lens without correct position_ids?" - Yes, not only users but possibly us as well because it's something that's harder to figure out when done wrong imo :D


# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
@@ -531,14 +547,7 @@ def _flash_attention_forward(
)
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif (
position_ids is not None
and query_states.shape[0] == 1
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
):
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
batch_size = query_states.size(0)

if cu_seq_lens_q is None or cu_seq_lens_k is None:
79 changes: 57 additions & 22 deletions src/transformers/models/glm4v/modeling_glm4v.py
@@ -279,14 +279,15 @@ def eager_attention_forward(
class Glm4vVisionAttention(nn.Module):
def __init__(self, config: Glm4vVisionConfig) -> None:
super().__init__()
self.config = config
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = 1
self.scale = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = config.attention_dropout
self.is_causal = False

def forward(
@@ -295,23 +296,31 @@ def forward(
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)

cos, sin = position_embeddings
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)

attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -322,13 +331,17 @@ def forward(
query_states,
key_states,
value_states,
attention_mask,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=self.is_causal,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.squeeze(0)

attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
@@ -348,13 +361,15 @@ def forward(
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -452,6 +467,25 @@ def rot_pos_emb(self, grid_thw):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids

def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None

seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask

def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand Down Expand Up @@ -481,14 +515,15 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)

for blk in self.blocks:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)

hidden_states = self.post_layernorm(hidden_states)

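For illustration (not part of the diff): a standalone sketch of the block-diagonal mask that the new `_prepare_attention_mask` helper builds from `cu_seqlens` for non-flash backends. The `cu_seqlens` values and dtype are assumptions for a toy example; with `flash_attention_2` the helper returns `None` and the same structure is conveyed via `cu_seq_lens_*`/`max_length_*` instead.

```python
import torch

cu_seqlens = torch.tensor([0, 3, 5])  # two blocks: lengths 3 and 2
seq_length, dtype = int(cu_seqlens[-1]), torch.float32

# Start fully masked, then open up bidirectional attention inside each block.
mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

# 1 = can attend, 0 = masked out; tokens never attend across block boundaries.
print((mask[0, 0] == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
```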
86 changes: 28 additions & 58 deletions src/transformers/models/glm4v/modular_glm4v.py
@@ -50,8 +50,8 @@
Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLRotaryEmbedding,
Qwen2_5_VLTextModel,
Qwen2_5_VLVisionAttention,
Qwen2_5_VLVisionBlock,
apply_rotary_pos_emb_vision,
)
from ..qwen2_5_vl.processing_qwen2_5_vl import (
Qwen2_5_VLProcessor,
@@ -505,62 +505,12 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc
return embeddings


class Glm4vVisionAttention(nn.Module):
class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
def __init__(self, config: Glm4vVisionConfig) -> None:
super().__init__()
self.config = config
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = 1
self.scale = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.is_causal = False

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)

attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.squeeze(0)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output


class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
@@ -653,6 +603,25 @@ def rot_pos_emb(self, grid_thw):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids

def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks. Though it will not be a 100% match for FA2's `varlen` path
if self.config._attn_implementation == "flash_attention_2":
return None

seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask

def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand Down Expand Up @@ -682,14 +651,15 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)

for blk in self.blocks:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
)

hidden_states = self.post_layernorm(hidden_states)
