Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..masking_utils import create_masks_for_generate
from ..modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
Expand Down Expand Up @@ -677,6 +678,17 @@ def prepare_inputs_for_generation(
if encoder_attention_mask is not None:
model_inputs["attention_mask"] = encoder_attention_mask

if "flash" in self.config._attn_implementation and self._supports_attention_backend:
cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa_kwargs_from_position_ids(
position_ids, is_packed_sequence=False
Comment thread
vasqu marked this conversation as resolved.
)
model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
max_length_q=max_length_q,
max_length_k=max_length_k,
)

# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
if key not in model_inputs:
Expand Down
88 changes: 57 additions & 31 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,46 +190,28 @@ def _upad_input(
)


def _prepare_from_posids(query, key, value, position_ids, query_length):
def prepare_fa_kwargs_from_position_ids(position_ids, is_packed_sequence: bool = True):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
This function returns all the necessary kwargs to call `flash_attn_varlen_func`,
extracted from position_ids. The `position_ids` can be either a packed sequence or
the usual padded position ids, for example at inference time.
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Sequence length of the input queries.
is_packed_sequence (`bool`, *optional*, defaults to `True`):
Whether the input position ids are a packed sequence or not.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
The cumulative sequence lengths for the target (query) and source (key, value), used to index into
ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
`max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
kv_length = key.shape[1]
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))

# If the lengths are not equal, most probably we are in decoding stage with cache
# In that case the position ids will not always start with `0` and we need a better way to infer
# cumulative seq lengths.
if query_length != kv_length:
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

if not is_packed_sequence:
tensor_kws = {"dtype": torch.int32, "device": position_ids.device}
last_position_ids = position_ids[:, -1]

Expand All @@ -238,8 +220,9 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
)
max_length_k = int(last_position_ids.max()) + 1

batch_size, seq_len = query.shape[:2]
q_len = torch.ones(batch_size, **tensor_kws) if query_length == 1 else last_position_ids.add(1)
q_len = (
torch.ones(position_ids.size(0), **tensor_kws) if position_ids.shape[-1] == 1 else last_position_ids.add(1)
)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0)
max_length_q = int(q_len.max())
else:
Expand All @@ -264,6 +247,49 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
# for some models (e.g. qwen2-vl).
max_length_q = cu_seq_lens_q.diff().max().item()
max_length_k = max_length_q
return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)


def _prepare_from_posids(query, key, value, position_ids, query_length):
    """
    This function returns necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from position_ids.
    NOTE: ideally cumulative lengths should be prepared at the data collator stage

    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Position indices forwarded as-is to `prepare_fa_kwargs_from_position_ids`, which derives the
            cumulative sequence lengths from them. Presumably of shape (batch_size, sequence_length);
            confirm against callers.
        query_length (`int`):
            Sequence length of the input queries.

    Return:
        query (`torch.Tensor`):
            Flattened query state. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Flattened key state. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Flattened value state. Shape: (total_source_length, num_key_value_heads, head_dim).
        (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    # Read the key/value length before flattening, while dim 1 is still the sequence dimension.
    kv_length = key.shape[1]
    # If the lengths are not equal, most probably we are in decoding stage with cache; the position ids
    # then do not necessarily start at `0`, so the helper must not treat them as a packed sequence.
    is_packed_sequence = query_length == kv_length

    # Collapse every leading dimension so each tensor becomes (total_tokens, heads, head_dim).
    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    # The helper returns two pairs, `(cu_q, cu_k), (max_q, max_k)`, so the unpacking has to be nested;
    # a flat four-name unpack would raise `ValueError`.
    (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(
        position_ids, is_packed_sequence=is_packed_sequence
    )
    return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))


Expand Down