Skip to content
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
cd4c7cb
use partial to wrap around `transformers` utils!
ArthurZucker Jul 17, 2025
005f482
try to refactor?
ArthurZucker Jul 17, 2025
1b834a4
revert one wrong change
ArthurZucker Jul 17, 2025
d93f366
just a nit
ArthurZucker Jul 17, 2025
2b7d411
push
ArthurZucker Jul 17, 2025
affba20
revert whatever was wrong!
ArthurZucker Jul 17, 2025
1959eb2
some nits
ArthurZucker Jul 17, 2025
888cd40
fixes when there is no attention mask
ArthurZucker Jul 17, 2025
8f5e62b
Merge branch 'main' of github.com:huggingface/transformers into kerne…
ArthurZucker Jul 17, 2025
5a7ae11
bring the licence back
ArthurZucker Jul 17, 2025
c57673b
some fixes
ArthurZucker Jul 17, 2025
7d69d83
nit
ArthurZucker Jul 17, 2025
7e94910
Merge branch 'kernels-flash-attn' of github.com:huggingface/transform…
ArthurZucker Jul 17, 2025
112e2a6
style
ArthurZucker Jul 17, 2025
501aa7e
remove prints
ArthurZucker Jul 17, 2025
04088be
correct dtype
ArthurZucker Jul 17, 2025
b1e104b
fa flags for testing
vasqu Jul 17, 2025
7087e7b
update
ArthurZucker Jul 17, 2025
cc58aca
Merge branch 'main' into kernels-flash-attn
ArthurZucker Jul 17, 2025
6a2996a
use paged attention if requested!
ArthurZucker Jul 18, 2025
8ddc525
Merge branch 'kernels-flash-attn' of github.com:huggingface/transform…
ArthurZucker Jul 18, 2025
a586294
updates
ArthurZucker Jul 18, 2025
57842f5
a clone was needed, not sure why
ArthurZucker Jul 18, 2025
43b7f32
automatically create cu seq lens when input is flash, this at least m…
ArthurZucker Jul 18, 2025
12bad1b
simplify and improve?
ArthurZucker Jul 18, 2025
c0b600a
flash attention is kinda broken on recent cuda version so allow the o…
ArthurZucker Jul 21, 2025
5c64874
Merge branch 'main' into kernels-flash-attn
ArthurZucker Jul 21, 2025
11e5000
fix!
ArthurZucker Jul 21, 2025
1c07350
protect kernels import
ArthurZucker Jul 21, 2025
cdaa1eb
update
ArthurZucker Jul 22, 2025
767d585
properly parse generation config being passed
ArthurZucker Jul 22, 2025
10f866e
Merge branch 'kernels-flash-attn' of github.com:huggingface/transform…
ArthurZucker Jul 22, 2025
c75c539
revert and update
ArthurZucker Jul 22, 2025
a2f3126
add two tests
ArthurZucker Jul 22, 2025
63b01c3
Merge branch 'main' of github.com:huggingface/transformers into kerne…
ArthurZucker Jul 22, 2025
85829d7
some fixes
ArthurZucker Jul 22, 2025
56981a5
fix test FA2
ArthurZucker Jul 22, 2025
b3f7a49
takes comment into account
ArthurZucker Jul 22, 2025
21e07f7
fixup
ArthurZucker Jul 22, 2025
a8b7ec6
revert changes
ArthurZucker Jul 22, 2025
f111d33
revert the clone, it is only needed because the metal kernel is not d…
ArthurZucker Jul 22, 2025
cd98c1f
[docs] update attention implementation and cache docs (#39547)
zucchini-nlp Jul 22, 2025
f457a08
fix mps on our side for now
ArthurZucker Jul 22, 2025
38d241b
Update src/transformers/integrations/flash_paged.py
ArthurZucker Jul 22, 2025
cb58187
Merge branches 'main' and 'kernels-flash-attn' of github.com:huggingf…
ArthurZucker Jul 22, 2025
c0f4f09
no qa
ArthurZucker Jul 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/transformers/generation/continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,8 @@ def __init__(
self._request_lock = threading.Lock()
self.model.generation_config.top_p = None
self.do_sample = getattr(generation_config, "do_sample", True)
self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
generation_config = model.generation_config if generation_config is None else generation_config
self.logit_processor = self.model._get_logits_processor(generation_config)
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,24 @@ def prepare_inputs_for_generation(
if encoder_attention_mask is not None:
model_inputs["attention_mask"] = encoder_attention_mask

if "flash" in self.config._attn_implementation and self._supports_attention_backend:
tensor_kws = {"dtype": torch.int32, "device": self.device}
pos = model_inputs["position_ids"][:, -1]

cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
max_length_k = int(pos.max()) + 1

bs, seq_len = input_ids.size()
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
max_length_q = int(q_len.max())

model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
max_length_q=max_length_q,
max_length_k=max_length_k,
)
# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
if key not in model_inputs:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/integrations/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def flash_attention_forward(
"FlashAttention does not support inputs with dim=0.\n"
"Please check your input shapes or use SDPA instead."
)

# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
Expand Down Expand Up @@ -76,6 +75,7 @@ def flash_attention_forward(
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
attn_implementation=module.config._attn_implementation,
layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
**kwargs,
)

Expand Down
13 changes: 8 additions & 5 deletions src/transformers/integrations/flash_paged.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
pass
Comment thread
ArthurZucker marked this conversation as resolved.
Outdated


def paged_attention_forward(
Expand All @@ -20,6 +20,7 @@ def paged_attention_forward(
max_seqlen_q=None,
max_seqlen_k=None,
block_tables=None,
implementation=None,
**kwargs,
) -> torch.Tensor:
r"""Perform the forward pass of attention with paged key-value cache.
Expand All @@ -46,12 +47,14 @@ def paged_attention_forward(
"""
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)

if implementation is not None:
flash_attn_varlen_func = implementation.flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
q.transpose(1, 2).squeeze(0),
k.transpose(1, 2).squeeze(0),
v.transpose(1, 2).squeeze(0),
q.transpose(1, 2).squeeze(0).contiguous(),
k.transpose(1, 2).squeeze(0).contiguous(),
v.transpose(1, 2).squeeze(0).contiguous(),
cumulative_seqlens_q.to(torch.int32),
cumulative_seqlens_k.to(torch.int32),
cumulative_seqlens_k.to(torch.int32).clone(),
max_seqlen_q,
Comment thread
ArthurZucker marked this conversation as resolved.
max_seqlen_k,
softmax_scale=module.scaling,
Expand Down
Loading
Loading