diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 848c2a214113..fec934221467 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -222,16 +222,18 @@ def _prepare_from_posids(query, key, value, position_ids): query = query.contiguous().view(-1, query.size(-2), query.size(-1)) key = key.contiguous().view(-1, key.size(-2), key.size(-1)) value = value.contiguous().view(-1, value.size(-2), value.size(-1)) + cu_seqlens_k = torch.cat( [torch.tensor([0], dtype=torch.int32, device=query.device), position_ids[:, -1].cumsum(dim=0) + 1], dim=0 ) max_k = torch.max(position_ids, dim=1).values.max().item() + 1 + position_ids = position_ids.flatten() indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) cu_seq_lens = torch.cat( ( - torch.tensor([0], device=position_ids.device, dtype=torch.int32), + indices_q[position_ids == 0], torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), ) )