Add torch.compile support to flash attention 3 (#1769)
Conversation
hopper/flash_api.cpp
Outdated
```diff
 }
-    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
+    return { dq.clone(), dk.clone(), dv.clone(), softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
```
What's going on here? Adding additional clones is generally bad for performance.
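For context (an illustrative sketch, not FA3 code): `.clone()` materializes a fresh copy of the tensor, so returning clones from the backward adds a full extra allocation plus a read and write of every gradient tensor.

```python
import torch

x = torch.randn(1024, 1024)
y = x.clone()  # allocates new storage and copies every element

# The clone lives in separate memory, which is the extra cost being
# questioned in this review comment.
assert y.data_ptr() != x.data_ptr()
assert torch.equal(x, y)
```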
hopper/build.sh
Outdated
```diff
-# Flash Attention Minimal Build Script for PHI-1 Reproducer
-# Uses subshell to automatically clean up environment variables
+# Run in subshell - variables are automatically cleaned up when it exits
```
do you have more context for this file? I don't see a mention of PHI-1 elsewhere
I'll remove this file once the PR gets approved. I only keep it as it makes it easier to build and test flash attention.
hopper/test_flash_attn.py
Outdated
```python
if not DISABLE_FAKE_CHECK:
    flash_attn_func = run_fake_check(flash_attn_func)
    flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
```
@tridao, are you ok with running fake_check as part of the tests? I can invert the flag to be opt-in instead of opt-out.
would that slow down the tests much? Rn it takes 30-50 mins to run all the tests
I suspect it could double the time to run the tests, since fake_check runs the function twice to compare the fake version with the actual implementation. I can change this flag to be opt-in instead, so fake_check only runs when the flag is enabled.
imo it would be good to have at least some of the tests on by default to prevent bit-rot. how about an OPCHECK_FREQ instead of ENABLE, where 1 means every test runs with it and 100 means every 100th test runs with it, defaulting to 100? that would increase testing time currently by no more than 1%.
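One possible shape for the OPCHECK_FREQ idea (a hedged sketch; the env-var name, the `check` parameter, and the wrapper name are placeholders, not the PR's actual code):

```python
import itertools
import os

# Run the expensive check only on every Nth call: OPCHECK_FREQ=1 checks
# every call, OPCHECK_FREQ=100 (the proposed default) checks ~1% of calls.
OPCHECK_FREQ = int(os.environ.get("FLASH_ATTENTION_OPCHECK_FREQ", "100"))


def run_fake_check(fn, check=lambda fn, args, kwargs: None):
    counter = itertools.count()

    def wrapper(*args, **kwargs):
        if next(counter) % OPCHECK_FREQ == 0:
            check(fn, args, kwargs)  # e.g. fake_check / torch.library.opcheck
        return fn(*args, **kwargs)

    return wrapper
```

With the default of 100, a 98k-test run would execute the check roughly 1,000 times rather than on every test.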
```python
        fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper
```
Do we have a sense of how comprehensive the existing tests in this file are? Are they good at exercising a variety of inputs?
I guess @tridao would be a better person to answer this one.
we do test a lot of input shapes and different attn options (~100k tests iirc)
@guilhermeleobas Any updates? I tried your commits and found there are still some shape errors in the backward function. Here is my test code:

```python
batch_size = 16
sequence_length = 256
test = EfficienctMultiHeadAttention(embedding_dim, num_heads=16).cuda().bfloat16()
out = test(input_tensor)
```
Hi @lantudou, thanks for the reproducer. Could you try it again?
@tridao, could you take a look one more time once you have some cycles to spare?
I think the torch custom ops should be registered to a
Is flash attention 3 an independent package from FA2, in the sense that FA2 will eventually be deprecated in favor of FA3?
idk what the future plan is but at least right now transformers thinks it can have both at the same time |
Got it. I changed the namespace to
Thank you for the great work @guilhermeleobas. Just wondering, have you tried running training on the exported FX graph? It seems like this implementation is still missing register_autograd, which is needed to run the training loop.
Thanks for the feedback, @Tomcli. As for the second question: probably it won't work. I based my implementation on what is implemented for FA2, which doesn't use

Edit: added in this one.
Hi @Tomcli, could you try the last commit, please?
Force-pushed 871a3cf to e52508f
thanks for fixing it @guilhermeleobas, compile now works, but the compiled artifacts still break for me when using AOTInductor packaging.
Thanks for trying this PR @OutofAi. Do you have a reproducer for this error?
I just want to share that for my workflow, based only on
@v0i0 I was able to reproduce this error without torch.compile. I changed the code to run flash_attn_varlen_func twice, and the grad values differ between runs. Since no compilation is involved, can we assume this issue is independent of the changes introduced here? Below are the changes I added to the test file:

```python
def copy_inputs(args, kwargs):
    new_args = [arg.clone().detach().requires_grad_(arg.requires_grad)
                if isinstance(arg, torch.Tensor) else arg for arg in args]
    new_kwargs = {k: v.clone().detach().requires_grad_(v.requires_grad)
                  if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return new_args, new_kwargs


def run_forward_backward(f, args, kwargs):
    args_, kwargs_ = copy_inputs(args, kwargs)
    out = f(*args_, **kwargs_)
    sm = out.sum().abs()
    diff_args = [arg for arg in args_ if isinstance(arg, torch.Tensor) and arg.requires_grad]
    out_grad = torch.autograd.grad(sm, diff_args, allow_unused=True)
    return out, out_grad


def run_opcheck(fn):
    def wrapper(*args, **kwargs):
        if should_run_schema_check(args, kwargs):
            safe_schema_check(fn, args, kwargs)
        if should_run_fake_check(args, kwargs):
            safe_fake_check(fn, args, kwargs)
        if should_test_backward(args, kwargs):
            # Here fn is flash_attn_func or flash_attn_varlen_func
            # and we run forward and backward twice to compare grads
            out, out_grad = run_forward_backward(fn, args, kwargs)
            out_2, out_grad_2 = run_forward_backward(fn, args, kwargs)
            torch.testing.assert_close(out, out_2)
            torch.testing.assert_close(out_grad, out_grad_2)
        return fn(*args, **kwargs)
    return wrapper
```
i think it's expected that bwd is not exactly the same between runs for FA. I still see errors of the shape e.g. in
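The run-to-run backward differences are consistent with floating-point non-associativity (FA's backward accumulates into buffers like `dq_accum`, so summation order can vary between runs). Even in pure Python, reordering a float sum perturbs the result, which is why tolerance-based rather than bitwise comparison is the appropriate check:

```python
import random

random.seed(0)
vals = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

# Same numbers, two accumulation orders.
s_fwd = sum(vals)
s_rev = sum(reversed(vals))

# Bitwise equality is too strict for reordered float sums;
# compare with a relative tolerance instead.
assert abs(s_fwd - s_rev) <= 1e-9 * max(1.0, abs(s_fwd), abs(s_rev))
```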
@v0i0, I couldn't reproduce the error you posted. I'm using a build of FA3 with all flags enabled.

```shell
$ FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK=FALSE FLASH_ATTENTION_ENABLE_OPCHECK=TRUE python -m pytest test_flash_attn.py --tb=short -k "test_flash_attn_varlen_output[2048-2048-64-True-False-True-15.0-False-False-mha-dtype0]" -rs
============================= test session starts ==============================
platform linux -- Python 3.12.11, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/git/flash-attention/hopper
collected 98751 items / 98750 deselected / 1 selected

test_flash_attn.py .                                                     [100%]
================== 1 passed, 98750 deselected in 10.76s ========================
```

Could you share the stacktrace? And which flags are you building FA with?
i just tried with your latest changes & pytorch nightly, and i can't reproduce it. so i think we're good to go :-) |
@tridao any objections to merging this?
@guilhermeleobas Does FAv3 now work under fullgraph too? And same question about FAv2.

I've also found an interesting case: for torch.func.grad to work on an FA-enabled model, this fix was needed:

Maybe worth adding such a torch.func.grad test into the FA codebase too, to check correctness of the torch.Library registrations.
This PR is FA3 specific (code under

We do have some tests for that in flash-attention/hopper/test_flash_attn.py (lines 101 to 104 in bb2efb3).
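A minimal smoke test along the lines suggested above might look like the following (hedged sketch: `loss_fn` is a stand-in scalar function, not an actual FA-enabled model):

```python
import torch
from torch.func import grad

# Stand-in for a model whose forward would call the FA custom ops.
def loss_fn(x):
    return (x ** 2).sum()

x = torch.randn(4)

# torch.func.grad exercises the functional-autograd path, which depends
# on the torch.library registrations being correct.
g = grad(loss_fn)(x)
assert torch.allclose(g, 2 * x)
```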
@guilhermeleobas It looks like after this change, the

Test output:
Hi @janeyx99, thanks for flagging this. This is indeed a backward-incompatible change: I had to update the

flash-attention/hopper/flash_attn_interface.py (lines 270 to 272 in 3e87e42)

I also believe this change should have minimal user impact, since

edit: Created #2153 but will wait for @v0i0 and @tridao's decision on this.
Is backward working for

I'm on the main branch of this repo.
@varunneal did you also build FA3 on main?
@janeyx99 Yes, I built from the hopper repo on main via these commands
@varunneal do you have a reproducer?
@guilhermeleobas @janeyx99 I found the mistake, it was on my end. Sorry about that, thanks!
…s/fa3-compile Add torch.compile support to flash attention 3
Enable torch.compile support for FlashAttention and improve testing (flash_attn_config.py)