Conversation
FA4 support
cc @stas00 if you wanna play around with this PR. It's pretty much ready, just not convinced by the numbers, but I also don't have quick access to a Blackwell GPU (at least today :D)
sfc-gh-sbekman left a comment
Thank you for working on this, Anton. Going to try it out.
To make it easier to try your PR, please add to the OP how to install FA4, since it's non-trivial to find:
git clone https://github.com/Dao-AILab/flash-attention/
cd flash-attention
cd flash_attn/cute
uv build --wheel . -v --no-build-isolation --out-dir flash-attention/wheels
uv pip install flash-attention/wheels/flash_attn_cute*.whl --prerelease=allow
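A quick sanity check that the install worked could look like this (both candidate module names are assumptions based on the wheel filename above and may differ upstream):

# Probe for the FA4 cute module; the names tried are assumptions
# derived from the flash_attn_cute*.whl wheel name.
import importlib.util

for name in ("flash_attn.cute", "flash_attn_cute"):
    try:
        found = importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        found = False
    if found:
        print(f"found FA4 cute module as '{name}'")
        break
else:
    print("FA4 cute module not found - check the wheel install step")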
OK, gave it a test ride using your PR and the above comment's install of FA4 on a B200. I did a quick test with Llama-8b and the integration worked smoothly, but the TFLOPS performance is much worse than FA2: 2-5x slower. Not sure if it's an issue with the integration, the FA4 code, or the PyTorch version. Most likely the upstream, since the integration is just a wrapper. I tried pt-2.9.1-cu130 and pt-nightly-cu130 - same outcome.
Thanks for checking this out and all the pointers @sfc-gh-sbekman ❤️
For sure, I'll add some docs for FA4 before release. Maybe also FA3 in a different PR.
Shoot, so it wasn't a GPU arch issue... This is weird.
Do you have a code snippet? There are so many edge cases with sdpa that it may not even be entering the FA backend path. This could be quickly checked by restricting the backend usage on SDPA with their context manager:

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
    pass  # do your thing here

I'm also unsure how FA4 is integrated in SDPA. Do we need to use a flag there? I remember that the cudnn backend needed special treatment.
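Expanding on that, a self-contained version of the check (assumes a CUDA GPU and a recent torch); forcing the flash backend makes sdpa raise instead of silently falling back, which quickly tells you whether it was entering the FA path at all:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# With only FLASH_ATTENTION allowed, unsupported inputs raise a RuntimeError
# instead of being dispatched to another backend.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)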
Did you mean that you too have observed a similar slowdown?
I was just using the normal SFT training recipe from https://github.com/snowflakedb/ArcticTraining/ where I tried different attention mechanisms. Just normal fwd/bwd/step - nothing special added.
They copied/adapted the FA4 kernels, see #42435 - you'd need a PyTorch nightly for that to work.
I just did some quick numbers on inference, see the test I noted down in the PR description. I used an H100 there, and as you can see it's slower (not of the same magnitude as in your samples - I'd say it's a mixture of model size / context size).
Gotcha, I will try to separate our implementation from the base fn of the FA library to see if our wrappers cause this, or whether some perf regression happened somewhere else.
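A sketch of how that isolation could be measured (the two callables at the bottom are placeholders, not the actual integration code):

import torch

def time_fn(fn, iters=50, warmup=5):
    # Time a callable with CUDA events to avoid host-side overhead noise.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

# Compare the raw FA call against the wrapper path, e.g.
# time_fn(lambda: raw_fa_call(...)) vs. time_fn(lambda: wrapper_call(...)).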
Wrong link? My assumption/hunch was that maybe…
My apologies, here is the correct link: pytorch/pytorch#167348
Some useful updates from talking to Tri: …
Hi @sfc-gh-sbekman @vasqu, thanks for contributing! I think varlen is supported now. Do you mind testing FA4 again?
Just tried with the …
Tried again with today's FA4 and a new error is reported on H200 w/ varlen. Non-varlen works, but I haven't measured the performance comparison.
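For reference, a minimal varlen repro sketch using the FA2-style API; whether FA4's cute interface mirrors the flash_attn_varlen_func signature is an assumption:

import torch
from flash_attn import flash_attn_varlen_func

# Two packed sequences of lengths 3 and 5; 8 heads, head dim 64.
total_tokens, heads, dim = 8, 8, 64
q = torch.randn(total_tokens, heads, dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, 5, 5)
print(out.shape)  # (total_tokens, heads, dim)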
Hey @vasqu 👋 Just a quick heads-up that the API changes in this PR (shipped in the 5.4.0 release) introduced a breaking change for us. Over at lerobot, we were relying on the pre-5.4.0 behavior. Cheers!
Second thought: …
@vasqu let's add BC unless we had a deprecation cycle |
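A hedged sketch of the kind of BC shim being asked for (both function names are hypothetical):

import warnings

def _new_attention_interface(*args, **kwargs):
    ...  # the post-5.4.0 implementation would live here

def _old_attention_interface(*args, **kwargs):
    # Deprecated alias kept for one cycle so downstream users (e.g. lerobot)
    # get a warning instead of a hard break.
    warnings.warn(
        "_old_attention_interface is deprecated; use _new_attention_interface.",
        FutureWarning,
        stacklevel=2,
    )
    return _new_attention_interface(*args, **kwargs)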
Check out #45061, mb
🚨 Breaking change
Flash Attention versions older than 2.3.3 are no longer supported; the minimum is now 2.3.3 and on. This is due to the fact that anything older is 2+ years old (we deprecate torch in 2-year cycles, for example) as well as carrying a fairly high maintenance burden.
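A sketch of the kind of version guard this implies (the library's actual check may differ):

# Refuse Flash Attention installs older than the new 2.3.3 minimum.
import importlib.metadata

from packaging.version import Version

if Version(importlib.metadata.version("flash_attn")) < Version("2.3.3"):
    raise ImportError("flash_attn is too old; this release requires >= 2.3.3")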
Related issues and PRs
Fixes #42405
Closes #42404 as it has a lot of unnecessary logic and tests alongside it
Testing
Sanity Testing
First quick numbers (Hopper)
NOTE: FA4 is optimized for Blackwell; these are quick numbers on Hopper --> it's faster than FA2 but around the same on varlen, and between the other FAs on non-varlen.