xformers 25 removes Triton fmha and completely breaks stable-fast, with fix (MemoryEfficientAttentionTritonFwdFlashBwOp, TritonFlashAttentionOp) #135

Closed
tau0-deltav opened this issue Mar 16, 2024 · 0 comments · Fixed by #136


tau0-deltav commented Mar 16, 2024

xformers==0.25.0 has been released to PyPI, built against torch==2.2.1. If you (or the webUI you're using) call sfast with xformers enabled, it no longer works: there's an error about a missing attribute, and the exception surfaces somewhere far away and unhappy from the actual cause.

tl;dr: open stable-fast/src/sfast/libs/xformers/xformers_attention.py and delete lines 14 through 16 (the ones that refer to the Triton ops), then reinstall with pip install -e ../stable-fast/
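If you want to confirm you're actually hitting this before patching anything, here's a quick throwaway check (my own snippet, not part of sfast) that probes the installed xformers for the two ops stable-fast still refers to:

# Quick check: does the installed xformers still expose the two Triton ops
# that sfast's xformers_attention.py references? On xformers >= 0.25.0 they
# are gone, which is exactly what breaks stable-fast.
from xformers import ops

for name in ("MemoryEfficientAttentionTritonFwdFlashBwOp", "TritonFlashAttentionOp"):
    status = "present" if hasattr(ops, name) else "MISSING (removed in 0.25.0)"
    print(f"{name}: {status}")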

I broke my back writing this tutorial for the completely lost. It's not great, but:

How to install stable-fast (and Python packages in general) from source, because nobody explains this stuff. Step -1: ignore this and follow the README. If that's not working, step 0: use Linux, not WSL. WSL slows things down, and you're using stable-fast to go fast - you want [fork()](https://en.wikipedia.org/wiki/Fork_(system_call)) (not to mention BTRFS and co.). Try to avoid Ubuntu (buggy, slow packages for CUDA and browsers) - Pop!_OS sounded good? I can't recommend my own distro for epistemic reasons, but Tumbleweed is similar.

The rest went into a gist here; it got very long, it is not information-dense, and it is probably wrong. Oh well, I didn't need this evening anyway.

This long report for a very simple issue is mostly aimed at people trying to fix it themselves.

Because the README is looking a little 'this is the final release' right now, and sfast is such an outrageously performant optimisation for sd.next that losing it is unacceptable. (Compiling with Inductor takes so long; with sfast I'd gotten used to changing the dimensions of my output every few images - the overhead now comes from sd.next's reluctance to reload the model without re-loading the weights - a whole second!)

stable-fast has not been compatible with xformers since the removal of xformers.ops.MemoryEfficientAttentionTritonFwdFlashBwOp in this commit. One way to get stable-fast working, then, is to use the xformers git revision just before that commit. Or... just use xformers==0.24.0.
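If you'd rather fail loudly up front than chase a missing-attribute error later, a minimal guard like this (my own sketch, nothing stable-fast actually ships; it assumes the packaging library is installed, which it almost always is alongside pip) makes the incompatibility obvious:

# Minimal guard (my own sketch, not part of stable-fast): bail out early on
# xformers builds that dropped the Triton fMHA ops.
import xformers
from packaging.version import Version  # assumes 'packaging' is available

if Version(xformers.__version__) >= Version("0.25.0"):
    raise RuntimeError(
        f"xformers {xformers.__version__} no longer ships the Triton fMHA ops; "
        "pin xformers==0.24.0 or patch sfast's xformers_attention.py (see below)."
    )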

The changelog (linked above) doesn't spell it out, but the removal seems to be motivated by improvements in Flash Attention 2, which now supports some sizes of attention operation that it previously didn't. I don't actually understand this stuff, you see.

So, if you go to xformers_attention.py:

from typing import Optional
import torch
from xformers.ops import (memory_efficient_attention, AttentionOp)
from xformers import ops
from sfast.utils.custom_python_operator import register_custom_python_operator

OP_STR_MAP = {
    ops.MemoryEfficientAttentionCutlassFwdFlashBwOp:
    'MemoryEfficientAttentionCutlassFwdFlashBwOp',
    ops.MemoryEfficientAttentionCutlassOp: 'MemoryEfficientAttentionCutlassOp',
    ops.MemoryEfficientAttentionFlashAttentionOp:
    'MemoryEfficientAttentionFlashAttentionOp',
    ops.MemoryEfficientAttentionOp: 'MemoryEfficientAttentionOp',
    ops.MemoryEfficientAttentionTritonFwdFlashBwOp:        # line 14 - delete
    'MemoryEfficientAttentionTritonFwdFlashBwOp',          # line 15 - delete
    ops.TritonFlashAttentionOp: 'TritonFlashAttentionOp',  # line 16 - delete
}

STR_OP_MAP = {v: k for k, v in OP_STR_MAP.items()}
...more code mumbo-jumbo... 

and you delete lines 14, 15 and 16, so that little dictionary horror at the bottom no longer makes reference to ops which don't exist. Stable-fast then simply doesn't specify which exact attention implementations xformers should use, and xformers, in principle, picks one for you automatically.

Oh, and the k and v in the STR_OP_MAP comprehension are just conventional Python names for dictionary keys and values - no relation to the attention keys and values. The trimmed map ends up looking like this:

OP_STR_MAP = {
    ops.MemoryEfficientAttentionCutlassFwdFlashBwOp:
    'MemoryEfficientAttentionCutlassFwdFlashBwOp',
    ops.MemoryEfficientAttentionCutlassOp: 'MemoryEfficientAttentionCutlassOp',
    ops.MemoryEfficientAttentionFlashAttentionOp:
    'MemoryEfficientAttentionFlashAttentionOp',
    ops.MemoryEfficientAttentionOp: 'MemoryEfficientAttentionOp',
}
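A slightly more defensive variant (my own sketch, not necessarily what #136 does, and assuming the rest of the file only ever iterates over these two maps) would be to build the map from whatever ops the installed xformers actually exposes, so the same file works on both 0.24.x and 0.25.0+:

# Version-tolerant sketch (my assumption, not the merged fix): only register
# ops that the installed xformers actually exposes.
from xformers import ops

_OP_NAMES = [
    'MemoryEfficientAttentionCutlassFwdFlashBwOp',
    'MemoryEfficientAttentionCutlassOp',
    'MemoryEfficientAttentionFlashAttentionOp',
    'MemoryEfficientAttentionOp',
    'MemoryEfficientAttentionTritonFwdFlashBwOp',  # removed in xformers 0.25.0
    'TritonFlashAttentionOp',                      # removed in xformers 0.25.0
]

OP_STR_MAP = {
    getattr(ops, name): name for name in _OP_NAMES if hasattr(ops, name)
}

STR_OP_MAP = {v: k for k, v in OP_STR_MAP.items()}

Either way, dropping the Triton entries just means sfast can no longer name those ops explicitly; xformers' memory_efficient_attention defaults to automatic op selection when no op is passed, which is what the paragraph above is getting at.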

Why isn't this a pull request?

  1. I have no idea what I'm doing. The last time I worked with numeric computation I wrote a particle swarm optimiser. It was pretty cool - the Wikipedia article on that was longer than the one for SGD back then. I was a teenager, writing multithreaded C#! It worked in 10 dimensions! I'd peaked. That was the last complex computer program I ever wrote; I can barely read Python now. I can't write hello world to a file.
    I don't really understand how this patch affects sfast, just that it makes it compatible. It might well not be faster than not using xformers at all - I've not used sfast outside of sd.next, so I wouldn't know.
    I don't know whether I should have substituted the new Flash Attention ops for the Triton ones somehow, but the CUTLASS path is six months old, so we can assume it still can't do whatever it needed Triton for six months ago.
  2. The last time I tried to use git to do something trivial that wasn't resetting the entire working tree, I nearly cried. I don't know how to do a pull request, and my therapist (a sycophantic Mixtral finetune) thinks I shouldn't have to find out.

*Honestly, in a straight line the H100 is what, 30% faster than a 3090? That's until you try to run VTOL VR, of course. The H100 just can't cope.

Thank you very much for this very cool code, chengzeyi. I (very seriously) wish I were as clever as you.
I'm going to go smoothie my brain with some Julia now instead.
