Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
e82beeb
initial implementation
vasqu Nov 26, 2025
1a45b34
Merge branch 'main' into fa4-support
vasqu Nov 27, 2025
30c6682
CB support
vasqu Nov 27, 2025
e9cdeea
change how we call item on max_seq_len_q/k
vasqu Nov 27, 2025
40168b4
fix
vasqu Nov 27, 2025
91a1b3b
tests
vasqu Nov 27, 2025
8d3dc6c
fix fa2 clash
vasqu Nov 27, 2025
bf1d589
unify the fa dispatch
vasqu Nov 27, 2025
f5b7f9c
fix
vasqu Nov 27, 2025
6288f44
modernbert...
vasqu Nov 27, 2025
15ed2eb
oops
vasqu Nov 27, 2025
6be5bbe
parity test
vasqu Nov 28, 2025
dad1b04
style
vasqu Nov 28, 2025
34c15c2
nit
vasqu Nov 28, 2025
ac8e309
Merge branch 'main' into fa4-support
vasqu Mar 6, 2026
776a1af
fixup imports for fa4
vasqu Mar 6, 2026
cc7a1b7
enable attention sinks, fixup logits checks in parity test
vasqu Mar 6, 2026
ca26ecf
Merge branch 'main' into fa4-support
vasqu Mar 6, 2026
65912b2
style
vasqu Mar 6, 2026
d07749f
change dispatch logic and introduce lower bound for FA
vasqu Mar 9, 2026
95d644e
Merge branch 'main' into fa4-support
vasqu Mar 9, 2026
ed88dcc
style
vasqu Mar 9, 2026
7fba6df
fix test
vasqu Mar 9, 2026
27acafe
min fa2, avoid 2x device sync
vasqu Mar 9, 2026
ba9fb59
Merge branch 'main' into fa4-support
vasqu Mar 9, 2026
afa0940
style
vasqu Mar 9, 2026
7223fe6
simple min version instead of list
vasqu Mar 10, 2026
da88dcf
fixup error message on non init check
vasqu Mar 10, 2026
654db43
fix up non-init check a tad more
vasqu Mar 10, 2026
a690e55
Merge branch 'main' into fa4-support
vasqu Mar 13, 2026
d3485da
refactor some FA constants out to main fa utils
vasqu Mar 13, 2026
476789f
new marker for all fas needed
vasqu Mar 13, 2026
19e4c44
oops
vasqu Mar 13, 2026
08445b6
style and make the fa kernel fallback generalized
vasqu Mar 13, 2026
920bef7
default none...
vasqu Mar 13, 2026
8ee8c56
more refactors
vasqu Mar 13, 2026
cd2a9b3
style
vasqu Mar 13, 2026
27e0d58
fix
vasqu Mar 13, 2026
043f11f
this test is faulty even on main, xformers can handle any shape apparently
vasqu Mar 13, 2026
b0485b5
let's make this more robust, we should check for `None` within...
vasqu Mar 13, 2026
eae216e
fix
vasqu Mar 13, 2026
15f6ba9
oops
vasqu Mar 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def pytest_configure(config):
config.addinivalue_line("markers", "torch_export_test: mark test which tests torch export functionality")
config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention functionality")
config.addinivalue_line("markers", "flash_attn_3_test: mark test which tests flash attention 3 functionality")
config.addinivalue_line("markers", "flash_attn_4_test: mark test which tests flash attention 4 functionality")

os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true"

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ line-ending = "auto"
addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_4_test: marks tests related to flash attention 4 (deselect with '-m \"not flash_attn_4_test\"')",
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
Expand Down
1 change: 1 addition & 0 deletions src/transformers/masking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ class AttentionMaskInterface(GeneralInterface):
"eager": eager_mask,
"flash_attention_2": flash_attention_mask,
"flash_attention_3": flash_attention_mask,
"flash_attention_4": flash_attention_mask,
"flex_attention": flex_attention_mask,
}

Expand Down
43 changes: 33 additions & 10 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flash_attn_4_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
is_torch_xpu_available,
Expand All @@ -34,7 +35,7 @@

# TODO Deprecate when all models have the attention interface
def flash_attn_supports_top_left_mask():
if is_flash_attn_3_available():
if is_flash_attn_3_available() or is_flash_attn_4_available():
return False
if is_flash_attn_2_available():
return not is_flash_attn_greater_or_equal_2_10()
Expand All @@ -47,7 +48,8 @@ def flash_attn_supports_top_left_mask():
# TODO Deprecate when all models have the attention interface
def is_flash_attn_available():
return (
is_flash_attn_3_available()
is_flash_attn_4_available()
or is_flash_attn_3_available()
or is_flash_attn_2_available()
or is_torch_npu_available()
or is_torch_xpu_available()
Expand Down Expand Up @@ -82,10 +84,13 @@ def _lazy_imports(implementation: Optional[str]):
"""
is_fa2 = is_flash_attn_2_available()
is_fa3 = is_flash_attn_3_available()
is_fa4 = is_flash_attn_4_available()

pad_input, unpad_input = _pad_input, _unpad_input

if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
if (implementation == "flash_attention_2" and is_fa2) or (
implementation is None and is_fa2 and not is_fa3 and not is_fa4
):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
elif is_torch_npu_available():
Expand All @@ -94,8 +99,10 @@ def _lazy_imports(implementation: Optional[str]):
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
else:
if implementation == "flash_attention_3" or (implementation is None and is_fa3):
if implementation == "flash_attention_3" or (implementation is None and is_fa3 and not is_fa4):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
elif implementation == "flash_attention_4" or (implementation is None and is_fa4):
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
# Kernels fallback
else:
flash_attn_func = getattr(implementation, "flash_attn_func", None)
Expand Down Expand Up @@ -467,6 +474,8 @@ def _process_flash_attention_kwargs(
softcap: Optional[float] = None,
deterministic: Optional[bool] = None,
s_aux: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
supports_mapping: Optional[dict[str, bool]] = None,
**kwargs,
):
Expand Down Expand Up @@ -497,6 +506,10 @@ def _process_flash_attention_kwargs(
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
s_aux (`torch.Tensor`, *optional*):
Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
max_seqlen_q (`int`, *optional*):
The maximum sequence length in the query tensor during a varlen forward.
max_seqlen_k (`int`, *optional*):
The maximum sequence length in the key/value tensor during a varlen forward.
Return:
flash_kwargs (`dict`):
A dict of kwargs that are requested and supported.
Expand Down Expand Up @@ -529,6 +542,12 @@ def _process_flash_attention_kwargs(
if supports_mapping["s_aux"] and s_aux is not None:
flash_kwargs["s_aux"] = s_aux

if supports_mapping["max_seqlen_q"] and max_seqlen_q is not None:
flash_kwargs["max_seqlen_q"] = max_seqlen_q

if supports_mapping["max_seqlen_k"] and max_seqlen_k is not None:
flash_kwargs["max_seqlen_k"] = max_seqlen_k

return flash_kwargs


Expand Down Expand Up @@ -583,7 +602,8 @@ def _flash_attention_forward(
)

# Extract the flash attention kwargs that have been requested (and are supported by the implementation)
flash_kwargs = process_flash_kwargs_fn(
flash_kwargs = partial(
process_flash_kwargs_fn,
query_length=query_length,
key_length=key_states.size(1),
is_causal=is_causal,
Expand Down Expand Up @@ -619,15 +639,15 @@ def _flash_attention_forward(
if "mps" in str(q.device):
cu_seq_lens_k = cu_seq_lens_k.clone()

# Newer fa versions no longer accept `max_seqlen_(q|k)`
final_flash_kwargs = flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k)
out_unpad = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
**flash_kwargs,
**final_flash_kwargs,
)
if isinstance(out_unpad, tuple):
out_unpad = out_unpad[0]
Expand All @@ -650,6 +670,8 @@ def _flash_attention_forward(
if "mps" in str(q.device):
cu_seq_lens_k = cu_seq_lens_k.clone()

# Newer fa versions no longer accept `max_seqlen_(q|k)`
final_flash_kwargs = flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k)
out = flash_varlen_fn(
q,
k,
Expand All @@ -658,7 +680,7 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
**flash_kwargs,
**final_flash_kwargs,
)
if isinstance(out, tuple):
out = out[0]
Expand All @@ -667,7 +689,8 @@ def _flash_attention_forward(

# No padding
else:
out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
final_flash_kwargs = flash_kwargs()
out = flash_fn(query_states, key_states, value_states, **final_flash_kwargs)
if isinstance(out, tuple):
out = out[0]

Expand Down
Loading