Add --use-flash-attention flag. #7223
|
(For testing, I recommend the @gel-crabs branch |
|
One of the ruff failures is my fault, but can you fix the other one? |
This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
|
better? |
* Add --use-flash-attention flag. This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
|
Huh. Okay, I'll make that conditional. |
Just a note for people who will test it: you need to disable before installing. Reference: vladmandic/sdnext#3515 |
…I#7223 Signed-off-by: bigcat88 <bigcat88@icloud.com>
The construction using |
|
Can confirm... Wonder what AMD broke this time. Lemme have a look. |
|
Uh oh. Does anyone have a checkout? Ping @dejay-vu? The upstream CK commit in https://github.com/ROCm/flash-attention/tree/howiejay/navi_support can no longer be found. |
|
Lucky: I found an old checkout with howiejay's branch. I pushed a fork of howiejay's CK work and also cherry-picked the fix for the ROCm 6.4 issue. This seems to work: the branch is at https://github.com/FeepingCreature/composable_kernel/tree/howiejayz/supports_all_arch |
Thank you for your efforts. It can now be installed and used normally. However, after testing in a torch 2.6 + ROCm 6.4 environment, it is slower than --use-pytorch-cross-attention. |
|
What sort of speeds are you seeing? On Pytorch nightly, I get 3.7it/s with Pytorch cross attention and 4it/s with Flash Attention on my 7900 XTX. |
Is there a nightly version of PyTorch for ROCm 6.4 now? I used PyTorch 2.6.0 + ROCm 6.4 as officially provided by AMD. In the most basic Flux text-to-image workflow, I get 1.47 s/it with PyTorch cross attention and 1.59 s/it with Flash Attention on my 7900 XT; in the Flux everything-migration workflow, I get 4.12 s/it with PyTorch cross attention and 5.22 s/it with Flash Attention.
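Since the thread mixes it/s and s/it, a quick conversion makes the numbers above directly comparable (figures are from this comment; the helper itself is just illustrative):

```python
def to_it_per_s(seconds_per_it: float) -> float:
    """Convert seconds-per-iteration to iterations-per-second."""
    return 1.0 / seconds_per_it

# 7900 XT, basic Flux text-to-image workflow (numbers from the comment above):
pytorch = to_it_per_s(1.47)  # ~0.68 it/s
flash = to_it_per_s(1.59)    # ~0.63 it/s
slowdown = (1.59 - 1.47) / 1.47 * 100
print(f"{pytorch:.3f} it/s vs {flash:.3f} it/s; FA ~{slowdown:.0f}% slower here")
```

So on this setup Flash Attention is roughly 8% slower, the opposite of the 7900 XTX result earlier in the thread.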
|
Huh. Try SDXL so we're comparing the same thing? I don't know how Flux works. You can just use Pytorch nightly with ROCm 6.4, you don't have to use the "officially supported" ones. Usually it works. |
|
Could I use this with an RTX 3060? |
|
Probably! I don't know how much it'll do for you though. (Note you can't use the pip FlashAttention repos recommended here, those are AMD only. But with NVidia you can install the upstream FlashAttention instead.) |
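For NVIDIA cards, the upstream project publishes its own package; a hedged sketch of the install (check the upstream flash-attention README for your CUDA/torch combination first):

```shell
# Upstream FlashAttention (NVIDIA/CUDA builds); the AMD forks linked in this
# thread will not work here. --no-build-isolation is the upstream-recommended
# flag so the build sees your already-installed torch.
pip install flash-attn --no-build-isolation
```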
Okay, I'll give it a try with the SDXL model later. Let me explain: in my previous environment, with PyTorch 2.7 + ROCm 6.3, --use-flash-attention was faster than --use-pytorch-cross-attention; with ROCm 6.4, the speeds of the two have reversed.
The PyTorch nightly builds are only available for ROCm 6.3; I couldn't find one for ROCm 6.4. |
|
You can just use the one from Pytorch.org: Odd, I didn't observe any performance change at all. |
|
Jeez, yeah okay I'm at 3000Mhz. No wonder. |
Can you test if there is a difference between flash-attn and cross-attention? Since I can't compile it, is it worth it? Have you tried Sage Attention? |
|
I reliably get 10% more with FA, but some people have reported differently. SageAttn absolutely does not support AMD (they write manual shaders). |
When compiling I get these errors: |
|
My 7900 XT running SDXL can only reach 1 it/s, while your 7900 XTX reaches up to 5 it/s, five times faster than mine. There shouldn't be such a big gap, should there? |
|
@Hakim3i Isn't that 1.5, not SDXL? It's loading SD1ClipModel. Re your build error:
That's a 6600XT I think? I think they forgot to handle that specific GPU, lol. I can try to patch it but no guarantee. @githust66 Do you mean 9700 XT? Yeah AMD is really really bad at providing timely support for their own consumer hardware. It took two years for the 7900 XTX to get supported. It'll probably get better maybe. |
|
@Hakim3i Try to build again now. |
Sorry, silly me: I have 2 GPUs and forgot to hide one to build flash-attn. My integrated GPU doesn't want to display on Ubuntu and I haven't found a fix. Btw, when I built flash-attn these two custom nodes stopped importing, and I thought they were broken under Linux:
|
|
Wait, I can't read. Yeah, yours seems just about 10% faster than mine. I wonder if it's a GPU power-limit thing; maybe the Windows drivers are more aggressive. Or just a better GPU? |
It might be GPU drivers, I don't know for sure. There are more it/s under WSL than Ubuntu, but the problem with WSL is that it's a VM, and when I run WAN2.1 it just doesn't work: it uses all my 64GB of RAM.
|
Can you check what your TDP limit is set to? Mine is at 339W by default. Monitoring software claims it's power limited. |
|
I can't recall having needed to install that manually. I don't know what's going on with that error. I think 339W PPT Sustained is the same setting I have... so we have the same value there. That doesn't explain it. Try just making a separate ComfyUI folder? Then you can run with only default nodes. |
Yep, I can make a copy, but as you can see I get more it/s with flash-attn, so I might just use it. |
|
I mean - yay? Mission accomplished? :) |
50% of the workflows on the internet use WAS Node Suite; I kinda need to get it working. |
|
Yeah I really don't know what's going on with that, maybe back flash_attn out for now. |
For an easy fix, pip install diffusers==0.25.1. But my WAN workflow is no longer working; I'm also getting a Florence issue AND an error on WAN: But the basic workflow still works, so this flash-attn is just hit and miss. |
|
Okay I've updated to 25.04 and now I also get the rotary_emb issue, lol. I guess upstream switched to triton-based, I'mma go see if I can steal that commit. |
|
Pushed a fix maybe? Issue has gone away locally, no idea if the upstream code actually works. It looks pretty generic. |
So what, I just need to recompile? |
|
Yep, rerun the pip. I'm just pushing to that branch. (For others: edit: Woah, nice. After the update to Ubuntu 25.04 I'm seeing FA+ |
That's great. |
|
I... can not tell you what is up with that, that error makes no sense. |
I have torch from the AMD repo, if that clears things up? |
I was able to use flash-attn like this: It works and I'm getting 5.10 it/s with OC. And holy, it is much faster: I can make a 3-second 480p video in 160 seconds, whereas with --use-pytorch-cross-attention it's about 230 seconds. Edit: Asked ChatGPT and it delivered, so here's what I did: and I got the whl package, then: |
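For anyone repeating this, the ROCm navi-support forks discussed in this thread are typically built with the `GPU_ARCHS` environment variable; a hedged sketch for a 7900-series (gfx1100) card, assuming the fork's README still uses that variable:

```shell
# Build a wheel from the navi-support branch mentioned earlier in the thread.
export GPU_ARCHS=gfx1100   # RDNA3 (7900 XT/XTX); adjust to your GPU's arch
pip wheel "git+https://github.com/ROCm/flash-attention@howiejay/navi_support"
pip install flash_attn-*.whl
```

If you have more than one GPU (e.g. an iGPU), hide the extra one while building, as noted above.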
|
Nice, congrats on the performance! That warning is harmless, don't worry about it. |
|
There have been two updates on this triton branch: Dao-AILab/flash-attention#2217 and Dao-AILab/flash-attention#2178. That one's already in; the other one is still getting reviewed. Flash Attention 3 and Infinity Cache usage. Can you take a look? It's probably super easy for you to integrate that here. |











This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention. Without even using torch.compile, this can bring SDXL 1024x1024 from 5.5s to 5s.
(I did a bench over at https://www.reddit.com/r/buildapc/comments/1j2zfbv/9070_xt_vs_7900_xtx/mhkmv0x/ )
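The flag itself is just a backend selector. A minimal sketch of how such a CLI switch might route attention, with a graceful fallback when flash-attn isn't installed (hypothetical names, not the actual ComfyUI code):

```python
import argparse

def pick_attention_backend(use_flash: bool) -> str:
    """Return the attention backend to use for this run (illustrative only)."""
    if not use_flash:
        return "pytorch-cross-attention"
    try:
        import flash_attn  # noqa: F401  # AMD forks vs. upstream NVIDIA builds differ
        return "flash-attention"
    except ImportError:
        # Fall back gracefully when flash-attn is not installed.
        return "pytorch-cross-attention"

parser = argparse.ArgumentParser()
parser.add_argument("--use-flash-attention", action="store_true",
                    help="Prefer FlashAttention over PyTorch cross-attention.")
args = parser.parse_args(["--use-flash-attention"])
print(pick_attention_backend(args.use_flash_attention))
```

The fallback matters for the failure modes in this thread: a broken or missing flash-attn build degrades to the PyTorch path instead of crashing at startup.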