Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0ccc67d
fix flash attention
vasqu Aug 7, 2025
0f97be0
i got a stroke reading that comment
vasqu Aug 7, 2025
0a45416
change dropout kwarg back to before
vasqu Aug 7, 2025
92e2075
rename _fa3... as it's used for multiple variants and should work as …
vasqu Aug 7, 2025
22574dc
Merge branch 'main' into fix-fa-integration
vasqu Aug 7, 2025
3581bd6
simplify imports and support kwargs for fa
vasqu Aug 7, 2025
2f607ba
style
vasqu Aug 7, 2025
49ce7ae
fix comments order
vasqu Aug 7, 2025
5f7d937
small fix
vasqu Aug 8, 2025
d21095c
skip kernels test (causes cuda illegal memories w/o cleanup), fix fa …
vasqu Aug 8, 2025
36bfffb
style
vasqu Aug 8, 2025
9ec8c45
Merge branch 'main' into fix-fa-integration
vasqu Aug 8, 2025
1612b56
allow fullgraph by preloading on init
vasqu Aug 8, 2025
ba8fd00
make globals "private"
vasqu Aug 8, 2025
5240985
Merge branch 'main' into fix-fa-integration
vasqu Aug 8, 2025
3dbf11a
ci pls be happy
vasqu Aug 8, 2025
d9d8ff7
change skip conditions based on backend flag (indicating missing mask…
vasqu Aug 11, 2025
ec0fbf3
move globals support to a function to prepare kwargs
vasqu Aug 11, 2025
a6996f5
style
vasqu Aug 11, 2025
ce0e586
Merge branch 'main' into fix-fa-integration
vasqu Aug 11, 2025
86c9e81
generalize supported kwargs
vasqu Aug 11, 2025
3ae14ec
small change to doc
vasqu Aug 11, 2025
89b8f95
fix
vasqu Aug 11, 2025
24512e4
Merge branch 'main' into fix-fa-integration
vasqu Aug 11, 2025
ec41ac5
Merge branch 'main' into fix-fa-integration
vasqu Aug 11, 2025
74a8987
Merge branch 'main' into fix-fa-integration
vasqu Aug 12, 2025
8f91219
add comments
vasqu Aug 12, 2025
6a0a3a1
style
vasqu Aug 12, 2025
54ed29e
Merge branch 'main' into fix-fa-integration
vasqu Aug 12, 2025
7016548
revert prep during generate
vasqu Aug 12, 2025
d9da331
Merge branch 'main' into fix-fa-integration
vasqu Aug 12, 2025
a996bd5
style
vasqu Aug 12, 2025
07fafe1
revert weird style changes
vasqu Aug 12, 2025
a98fac4
add fa kwarg prep during generate with fixes back
vasqu Aug 12, 2025
9971d75
how did this even happen
vasqu Aug 12, 2025
5e2d35f
how
vasqu Aug 12, 2025
85cb1b1
Merge branch 'main' into fix-fa-integration
vasqu Aug 12, 2025
4ad364c
add comment
vasqu Aug 12, 2025
cb89dbe
Merge branch 'main' into fix-fa-integration
vasqu Aug 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,12 +681,12 @@ def prepare_inputs_for_generation(
tensor_kws = {"dtype": torch.int32, "device": self.device}
pos = model_inputs["position_ids"][:, -1]

cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], dim=0).to(**tensor_kws)
Comment thread
vasqu marked this conversation as resolved.
Outdated
max_length_k = int(pos.max()) + 1

bs, seq_len = input_ids.size()
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], dim=0).to(**tensor_kws)
max_length_q = int(q_len.max())

model_inputs.update(
Expand Down
Loading