Skip to content

🚨 [FA4] Initial support#42435

Merged
vasqu merged 42 commits intohuggingface:mainfrom
vasqu:fa4-support
Mar 13, 2026
Merged

🚨 [FA4] Initial support#42435
vasqu merged 42 commits intohuggingface:mainfrom
vasqu:fa4-support

Conversation

@vasqu
Copy link
Copy Markdown
Contributor

@vasqu vasqu commented Nov 26, 2025

🚨 Breaking change

  • FA2 is only supported from version 2.3.3 and on

This is due to the fact that this is older than 2+ years (we deprecate torch in 2 year cycles for example) as well as it giving fairly high maintenance burden.

Related issues and PRs

Fixes #42405
Closes #42404 as it has a lot of unnecessary logic and tests alongside it

Testing

Sanity Testing

RUN_SLOW=1 pytest tests/models/llama/test_modeling_llama.py -k flash
  • Passes all flash attention 4 tests

First quick numbers (hopper)

# No attention mask (base fa)
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
Latency:
    With FA2: 381.5204345703125
    With FA3: 362.461669921875
    With FA4: 373.788427734375

# With attention mask (varlen fa)
Latency:
    With FA2: 509.337646484375
    With FA3: 476.020654296875
    With FA4: 476.72578125

NOTE: FA4 is optimized for blackwell, these are quick numbers on hopper --> it's faster than FA2 but around the same on varlen, between the other FAs on non-varlen

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Copy link
Copy Markdown
Contributor Author

vasqu commented Nov 28, 2025

FA4 support cc @stas00 if you wanna play around with this PR. It's pretty much ready, just not convinced by the numbers but I also don't have quick access to a blackwell GPU (at least today :D)

Copy link
Copy Markdown
Contributor

@sfc-gh-sbekman sfc-gh-sbekman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for working on this, Anton. Going to try it out.

To make it easier to try your PR please add to the OP how to install FA4, since it's non-trivial to find.

git clone https://github.com/Dao-AILab/flash-attention/
cd flash-attention
cd flash_attn/cute
uv build --wheel . -v --no-build-isolation --out-dir flash-attention/wheels
uv pip install flash-attention/wheels/flash_attn_cute*.whl --prerelease=allow

@sfc-gh-sbekman
Copy link
Copy Markdown
Contributor

sfc-gh-sbekman commented Dec 1, 2025

OK, gave it a test ride using your PR and the above comment's install of FA4 on B200.

I did a quick test with Llama-8b and the integration worked smoothly but the tflops performance is much worse than FA2 - 2-5x slower. Not sure if it's an issue with integration or the FA4 code or the pytorch version - most likely the upstream since the integration is just a wrapper

I tried pt-2.9.1-cu130 and pt-nightly-cu130 - same outcome

edit: sdpa in pt-nightly, which supposedly backported FA4, is about 3x faster than FA4 on its own using the same llama-8b - since they both should be using the same code, perhaps there is an issue with the integration?

@vasqu
Copy link
Copy Markdown
Contributor Author

vasqu commented Dec 1, 2025

Thanks for checking this out and all the pointers @sfc-gh-sbekman ❤️

To make it easier to try your PR please add to the OP how to install FA4, since it's non-trivial to find.

For sure, I'll add some docs for FA4 before release. Maybe also FA3 in a different PR.

I did a quick test with Llama-8b and the integration worked smoothly but the tflops performance is much worse than FA2 - 2-5x slower. Not sure if it's an issue with integration or the FA4 code or the pytorch version - most likely the upstream since the integration is just a wrapper

Shoot, so it wasn't an GPU arch issue... This is weird

sdpa in pt-nightly, which supposedly backported FA4, is about 3x faster than FA4 on its own using the same llama-8b - since they both should be using the same code, perhaps there is an issue with the integration?

Do you have a code snippet? There are so many edge cases with sdpa that it maybe is not even entering the FA backend path? Could be quickly checked by restricting the backend usage on SDPA with their context manager

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
    pass  # do your thing here

I'm also unsure how FA4 is integrated in SDPA? Do we need to use a flag there? I remember that cudnn backend needed special treatment

@stas00
Copy link
Copy Markdown
Contributor

stas00 commented Dec 2, 2025

Shoot, so it wasn't an GPU arch issue... This is weird

Did you mean that you too have observed a similar slowdown?

Do you have a code snippet?

I was just using https://github.com/snowflakedb/ArcticTraining/ normal SFT training recipe where I tried different attention mechanisms. Just normal fwd/bwd/step - nothing special added.

I'm also unsure how FA4 is integrated in SDPA? Do we need to use a flag there? I remember that cudnn backend needed special treatment

They copied/adapted the FA4 kernels see: #42435 - you'd need pt nightly for that to work.

@vasqu
Copy link
Copy Markdown
Contributor Author

vasqu commented Dec 2, 2025

Did you mean that you too have observed a similar slowdown?

I just did some quick numbers on inference, see the test I noted down in the PR description. I used an H100 there and as you can see it's slower (not on the same magnitude as in your samples - would say it's a mixture of model size / context size)

I was just using https://github.com/snowflakedb/ArcticTraining/ normal SFT training recipe where I tried different attention mechanisms. Just normal fwd/bwd/step - nothing special added.

Gotcha, I will try to separate our implementation from the base fn of the FA library to see if our wrappers cause this or maybe some perf regression happened sometime else.

They copied/adapted the FA4 kernels see: #42435 - you'd need pt nightly for that to work.

Wrong link? My assumption / hunch was that maybe

  1. Even nightly might not use FA4 per default and sticks to FA2 per default, i.e. might need some extra flags to enable that specific backend. But that's just my feeling, need to look into it.
  2. Our implementations has some issues where attention masks are created even when it is not needed (full (causal) attention). If a mask is passed to SDPA, then the FA backend can never be entered per their restrictions. So I thought, maybe, we have this case (SDPA with xformers faster than FA4 - xformers is not so bad on short contexts <2k).

@stas00
Copy link
Copy Markdown
Contributor

stas00 commented Dec 2, 2025

My apologies, here is the correct link pytorch/pytorch#167348

@sfc-gh-sbekman
Copy link
Copy Markdown
Contributor

sfc-gh-sbekman commented Dec 12, 2025

Some useful updates from talking to Tri:

  • FA4 is supposed to replace FA2 and FA3 and would work with A/H/B archs
  • FA4 varlen support is planned in a few weeks time

@edixiong
Copy link
Copy Markdown

edixiong commented Feb 5, 2026

Hi @sfc-gh-sbekman @vasqu Thanks for contributing! I think the varlen is supported. Do you mind testing FA4 again?

@sfc-gh-sbekman
Copy link
Copy Markdown
Contributor

just tried with the main version of fa - bwd doesn't work:

[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 325, in backward
[rank3]:     torch.autograd.backward(outputs_with_grad, args_with_grad)
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
[rank3]:     return user_fn(self, *args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/interface.py", line 1385, in backward
[rank3]:     dq, dk, dv = _flash_attn_bwd(
[rank3]:                  ^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/interface.py", line 1059, in _flash_attn_bwd
[rank3]:     _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
[rank3]:                                                  ^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/cute_dsl_utils.py", line 118, in cute_compile_patched
[rank3]:     output = cute_compile_og(*args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 802, in __call__
[rank3]:     ).launch(
[rank3]:   ^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 1213, in kernel
[rank3]:     if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 1215, in then_block_16
[rank3]:     self.compute_loop(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 2062, in compute_loop
[rank3]:     while work_tile.is_valid_tile:
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 2317, in if_region_4
[rank3]:     if process_tile:
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 2339, in then_block_5
[rank3]:     consumer_state_dKV = self.epilogue_dK_or_dV_tma(
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 2877, in epilogue_dK_or_dV_tma
[rank3]:     tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
[rank3]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py", line 1834, in split_wg
[rank3]:     t = cute.logical_divide(
[rank3]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/_mlir/dialects/_cute_ops_gen.py", line 1805, in __init__
[rank3]:     super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: ValueError: Operation creation failed
loc("t = cute.logical_divide("("/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/flash_bwd_sm100.py":1834:16)): error: failed to perform a valid division of '!cute.layout<"(((32,32),1),1,1,1):(((1,65536),0),0,0,0)">' by #cute.tile<"[1024:1;1:0;1:0;0:1]">

@sfc-gh-sbekman
Copy link
Copy Markdown
Contributor

sfc-gh-sbekman commented Feb 24, 2026

Tried again with today's FA4 and a new error is reported on H200 w/ varlen.

[rank6]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/torch/autograd/function.py", line 311, in apply
[rank6]:     return user_fn(self, *args)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/interface.py", line 1395, in backward
[rank6]:     dq, dk, dv = _flash_attn_bwd(
[rank6]:                  ^^^^^^^^^^^^^^^^
[rank6]:   File "/home/yak/miniconda3/envs/dev/lib/python3.12/site-packages/flash_attn/cute/interface.py", line 617, in _flash_attn_bwd
[rank6]:     assert not is_varlen, "varlen backward is not yet supported on sm90"
[rank6]:            ^^^^^^^^^^^^^
[rank6]: AssertionError: varlen backward is not yet supported on sm90

non-varlen works, but haven't measured the performance comparison

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ty!

@vasqu vasqu changed the title [FA4] Initial support 🚨 [FA4] Initial support Mar 13, 2026
@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma3, gpt_oss, sam3

@vasqu vasqu added this pull request to the merge queue Mar 13, 2026
Merged via the queue into huggingface:main with commit 65db6fc Mar 13, 2026
28 checks passed
@vasqu vasqu deleted the fa4-support branch March 13, 2026 19:32
@imstevenpmwork
Copy link
Copy Markdown
Contributor

imstevenpmwork commented Mar 27, 2026

Hey @vasqu 👋 Just a quick heads-up that the API changes in this PR (shipped in the 5.4.0 release) introduced a breaking change for us.

Over at lerobot, we were relying on is_flash_attn_greater_or_equal_2_10, which looks like it got removed during the import_utils.py refactor. Just wanted to flag this to save some debugging time in case anyone else is suddenly staring at a red CI!

Cheers!

Second thought:
Maybe we can add it into the Release Notes. Users should move from is_flash_attn_greater_or_equal_2_10 to is_flash_attn_greater_or_equal("2_10")

@ArthurZucker
Copy link
Copy Markdown
Collaborator

ArthurZucker commented Mar 27, 2026

@vasqu let's add BC unless we had a deprecation cycle

@vasqu
Copy link
Copy Markdown
Contributor Author

vasqu commented Mar 27, 2026

Check out #45061, mb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Integrate FA4 (Flash Attention for Blackwell) into HF Transformers

9 participants