enable flash attention for image generation #1633
Conversation
Does it work well? Any speedups or memory savings? I haven't tried SD flash attention extensively.
Also, please note any potential incompatibilities; then we should be good to merge.
A quick test with SDXL + a DMD2 LoRA, rendering a 1024x1536 image (not including the VAE phase).
The GPU is a 7600 XT; it's one of the few that gets better LLM token-generation speed with flash attention on, so this may be close to a best case. One quirk: in some cases the image changes when flash attention is turned on. Not by much: the composition stays the same and the quality seems fine; it's mostly like the variation I get when switching between backends.
I did a few more tests on SDXL + Vulkan: in general, higher resolutions benefit more from flash attention. At 1024x1024 I'm getting around 500 MB VRAM / 18.0 s with flash attention, versus 900 MB / 21.5 s without; at 512x512 and 512x768, essentially the same ~200 MB VRAM and slightly faster inference (~4-8%). SD1.5 doesn't seem to benefit, even when rendering an 832x832 image: same VRAM usage and speed. So... pretty much confirmed it's working as @Green-Sky described in the --diffusion-fa PR (which I just found 🙂).
Oh cool. Last time I checked, only CUDA got faster; must have been upstream ggml updates. Also, IIRC SD1.x does not make use of it (yet). Might be remembering this wrong though.
FA requires F16, so something is potentially cast down (or up?).
It is currently ONLY for the diffusion model. Yes, some large dimensions don't work, but that's a ggml FA thing (somewhere past 1024x1024; I don't remember exactly).
Interesting... I did measure the same memory gains on quantized models, so I guess they're indeed being converted to F16 on demand. BTW, you mentioned that you saw no benefit for the VAE stage. Any chance of that being different now, with the newer backend implementations? |
I would have to check what the issue was, but there were a lot of cases that needed padding. Whether or not something makes sense depends very much on the shape of the tensors.
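For a concrete picture of the F16 point above, here is a minimal sketch of how a diffusion attention block can be routed through ggml's fused flash-attention op, with K/V cast to F16 on demand. `ggml_cast`, `ggml_flash_attn_ext`, and `ggml_soft_max_ext` are real ggml calls, but the exact `ggml_flash_attn_ext` signature shown follows a recent ggml revision, and the helper itself is an illustrative assumption rather than stable-diffusion.cpp's actual graph code.

```cpp
// Sketch only: not the actual stable-diffusion.cpp attention code.
#include <math.h>
#include "ggml.h"

// q/k/v are assumed to be laid out as [head_dim, n_tokens, n_heads, batch].
static struct ggml_tensor * attn_block(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        bool                  use_flash_attn) {
    const float scale = 1.0f / sqrtf((float) q->ne[0]);

    if (use_flash_attn) {
        // The fused kernel wants F16 K/V, so quantized or F32 tensors are
        // cast on demand; the full [n_tokens, n_tokens] score matrix is
        // never materialized, which is where the VRAM savings come from.
        k = ggml_cast(ctx, k, GGML_TYPE_F16);
        v = ggml_cast(ctx, v, GGML_TYPE_F16);
        return ggml_flash_attn_ext(ctx, q, k, v, /*mask*/ NULL, scale,
                                   /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);
    }

    // Fallback: naive attention, materializes the full score matrix.
    // (Output layout differs from the fused path; callers permute/reshape.)
    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
    kq = ggml_soft_max_ext(ctx, kq, /*mask*/ NULL, scale, /*max_bias*/ 0.0f);
    struct ggml_tensor * v_t = ggml_cont(ctx, ggml_transpose(ctx, v));
    return ggml_mul_mat(ctx, v_t, kq);
}
```

The fused op also has padding/alignment expectations on some of its inputs, which is the kind of shape constraint referred to above.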
LostRuins left a comment
lgtm then
This simply applies the existing config to the image generation parameters.
Tested on SD1.5, SDXL and Flux, both on ROCm and Vulkan.
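For reference, a minimal sketch of what "applies the existing config to the image generation parameters" amounts to, under assumed names (`loader_config`, `sd_gen_params`, `use_flashattention`, and `diffusion_flash_attn` are all illustrative, not the project's actual identifiers):

```cpp
// Sketch only: forwarding an already-configured flash-attention flag to the
// image-generation setup. All names here are hypothetical.
#include <cstdio>

struct loader_config {
    bool use_flashattention = false;  // existing setting already exposed to the user
};

struct sd_gen_params {
    bool diffusion_flash_attn = false; // analogous to stable-diffusion.cpp's --diffusion-fa
    int  width  = 1024;
    int  height = 1024;
};

static sd_gen_params make_sd_params(const loader_config & cfg) {
    sd_gen_params p;
    // The change amounts to plumbing the existing flag through to the
    // diffusion model's parameters instead of adding a new option.
    p.diffusion_flash_attn = cfg.use_flashattention;
    return p;
}

int main() {
    loader_config cfg;
    cfg.use_flashattention = true;
    const sd_gen_params p = make_sd_params(cfg);
    std::printf("diffusion flash attention: %s\n",
                p.diffusion_flash_attn ? "on" : "off");
    return 0;
}
```

The point is that no new setting is introduced; the flag the user already toggles is forwarded to the diffusion model, mirroring stable-diffusion.cpp's --diffusion-fa option.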