
Add torch.compile support to flash attention 3#1769

Merged
v0i0 merged 18 commits into Dao-AILab:main from guilhermeleobas:guilhermeleobas/fa3-compile
Dec 4, 2025

Conversation

@guilhermeleobas
Contributor

Enable torch.compile support for FlashAttention and improve testing

  • Add support for torch.compile to recognize FlashAttention forward/backward functions.
  • Update tests to use torch._subclasses.fake_tensor.fake_check to validate the FakeTensor implementation.
  • Create a file exposing, at runtime, the flags used to build FlashAttention (flash_attn_config.py)
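As a rough sketch of the last bullet, such a generated module could simply record which build-time flags were active (the constant and flag names below are illustrative, not the actual contents of flash_attn_config.py):

```python
# Hypothetical sketch of a generated flash_attn_config.py; setup.py would
# write these values out when the extension is compiled. Flag names here
# are illustrative only.
FLASH_ATTN_BUILD_FLAGS = {
    "DISABLE_SM80": True,
    "DISABLE_FP16": False,
}

def flag_enabled(name: str) -> bool:
    """Look up a build-time flag; unknown flags default to False."""
    return FLASH_ATTN_BUILD_FLAGS.get(name, False)
```

This lets tests and downstream code skip configurations that were compiled out.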

@guilhermeleobas guilhermeleobas marked this pull request as ready for review July 22, 2025 20:52
@guilhermeleobas
Contributor Author

cc @zou3519 @anijain2305

```diff
-    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
+    return { dq.clone(), dk.clone(), dv.clone(), softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
```

What's going on here? Adding additional clones is generally bad for performance.

hopper/build.sh Outdated
Comment on lines +5 to +8
```sh
# Flash Attention Minimal Build Script for PHI-1 Reproducer
# Uses subshell to automatically clean up environment variables

# Run in subshell - variables are automatically cleaned up when it exits
```

do you have more context for this file? I don't see a mention of PHI-1 elsewhere

Contributor Author

I'll remove this file once the PR gets approved. I only keep it as it makes it easier to build and test flash attention.

Comment on lines +61 to +63
```python
if not DISABLE_FAKE_CHECK:
    flash_attn_func = run_fake_check(flash_attn_func)
    flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
```
@zou3519 zou3519 Jul 23, 2025

A bit jank but I'm fine with it

Contributor Author

@tridao, are you ok with running fake_check as part of the tests? I can invert the flag to be opt-in instead of opt-out.

Member

Would that slow down the tests much? Right now it takes 30-50 minutes to run all the tests.

Contributor Author

I suspect it could double the test time, since fake_check runs the function twice to compare the fake implementation against the actual one. I can make the flag opt-in instead, so fake_check only runs when the flag is enabled.

Collaborator

imo it would be good to have at least some of the tests on by default to prevent bit-rotting. How about an OPCHECK_FREQ instead of ENABLE, where 1 means every test runs with it and 100 means every 100th test runs with it, defaulting to 100? That would increase testing time currently by no more than ~1%.
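The frequency idea above could be sketched as a thin wrapper (a hypothetical illustration, not code from this PR; `check` stands in for whatever validation is being sampled):

```python
import itertools

def sampled_check(fn, check, freq=100):
    """Run `check` on every `freq`-th call of `fn`; freq=1 checks every call."""
    counter = itertools.count(1)

    def wrapper(*args, **kwargs):
        # Only pay for the expensive check on a sampled subset of calls.
        if next(counter) % freq == 0:
            check(fn, args, kwargs)
        return fn(*args, **kwargs)

    return wrapper
```

With freq=100, only about 1% of test invocations pay the extra validation cost.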

```python
        fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper
```


Do we have a sense of how comprehensive the existing tests in this file are? Are they good at exercising a variety of inputs?

Contributor Author

I guess @tridao would be a better person to answer this one.

Member

we do test a lot of input shapes and different attn options (~100k tests iirc)

@zou3519 zou3519 left a comment

Looks good to me, modulo the prints

@lantudou

lantudou commented Aug 5, 2025

@guilhermeleobas Any update? I tried your commits and found there are still some shape errors in the backward function; here is my test code.
```python
import torch
from flash_attn_interface import flash_attn_func, _flash_attn_forward
from torch import nn


class EfficienctMultiHeadAttention(nn.Module):

    def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True):
        """
        Initialize.

        Args:
            embed_size (int): Input and output feature dimension.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout rate applied to the attention weights.
            use_flash_attn (bool): Whether to try FlashAttention (if available).
        """
        super().__init__()
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.use_flash_attn = use_flash_attn and (flash_attn_func is not None)

        # A single linear layer produces Q, K, V together, which is more efficient
        self.qkv_proj = nn.Linear(embed_size, 3 * embed_size)
        # Final output projection
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.dropout = dropout

    def forward(self, x, attention_mask=None):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (N, seq_length, embed_size)
            attention_mask (torch.Tensor, optional): Attention mask.
                For SDPA, True positions in a boolean mask are *not* attended to.
                For FlashAttention, the causal argument is typically used instead of a mask.

        Returns:
            torch.Tensor: Output tensor of shape (N, seq_length, embed_size)
        """
        N, seq_length, _ = x.shape

        # 1. Project and split Q, K, V
        # (N, seq_length, embed_size) -> (N, seq_length, 3 * embed_size)
        qkv = self.qkv_proj(x)
        # (N, seq_length, 3 * embed_size) -> 3 x (N, seq_length, embed_size)
        q, k, v = qkv.chunk(3, dim=-1)

        # 2. Reshape Q, K, V for multi-head computation
        # (N, seq_length, embed_size) -> (N, seq_length, num_heads, head_dim)
        q = q.view(N, seq_length, self.num_heads, self.head_dim)
        k = k.view(N, seq_length, self.num_heads, self.head_dim)
        v = v.view(N, seq_length, self.num_heads, self.head_dim)

        # --- FlashAttention path ---
        if self.use_flash_attn and attention_mask is None:
            # flash_attn_func expects shape (batch, seqlen, nheads, headdim),
            # which matches our current shape exactly
            # flash_attn_func handles dropout internally
            # print("Using FlashAttention...")
            out = flash_attn_func(
                q, k, v
            )
        # 3. Merge the multi-head outputs
        # (N, seq_length, num_heads, head_dim) -> (N, seq_length, embed_size)
        out = out.reshape(N, seq_length, self.embed_size)

        # 4. Final linear projection
        out = self.out_proj(out)

        return out


batch_size = 16
sequence_length = 256
embedding_dim = 2048

test = EfficienctMultiHeadAttention(embedding_dim, num_heads=16).cuda().bfloat16()
input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16()
test = torch.compile(test, mode='max-autotune')

out = test(input_tensor)
loss = out.sum()
loss.backward()
```

@guilhermeleobas
Contributor Author

Hi @lantudou, thanks for the reproducer. Could you try it again?

@guilhermeleobas
Contributor Author

@tridao, could you take a look one more time once you have some cycles to spare?

@m3rcuriel

I think the torch custom ops should be registered to a flash_attn_3 namespace to match the TORCH_LIBRARY registration in C++. Would these conflict with FA2?

@guilhermeleobas
Contributor Author

> I think the torch custom ops should be registered to a flash_attn_3 namespace to match the TORCH_LIBRARY registration in C++. Would these conflict with FA2?

Is flash attention 3 an independent package from FA2? In the sense that, in the future, FA2 will be deprecated in favor of FA3?

@m3rcuriel

> Is flash attention 3 an independent package from FA2? In the sense that, in the future, FA2 will be deprecated in favor of FA3?

idk what the future plan is but at least right now transformers thinks it can have both at the same time

@guilhermeleobas
Contributor Author

> idk what the future plan is but at least right now transformers thinks it can have both at the same time

Got it. I changed the namespace to flash_attn_3.

@Tomcli

Tomcli commented Aug 20, 2025

Thank you for the great work @guilhermeleobas. Just wondering, have you tried running training on the exported FX graph? It seems like this implementation is still missing register_autograd, which is needed to run the training loop.

@guilhermeleobas
Contributor Author

guilhermeleobas commented Aug 21, 2025

Thanks for the feedback, @Tomcli. As for the second question: it probably won't work. I based my implementation on the FA2 one, which doesn't use register_autograd either. I can do this in a follow-up PR if needed.

Edit: added in this one.

@guilhermeleobas
Contributor Author

Hi @Tomcli, could you try the last commit, please?

@OutofAi

OutofAi commented Aug 30, 2025

thanks for fixing it @guilhermeleobas, the compile now works, but the compiled artifacts still break when using AOTInductor packaging.

@guilhermeleobas
Contributor Author

Thanks for trying this PR @OutofAi. Do you have a reproducer for this error?

@Turakar

Turakar commented Sep 2, 2025

I just want to share that for my workflow, based only on torch.compile() and not torch.export(), this PR works. Thanks a lot!

@guilhermeleobas
Contributor Author

@v0i0 I was able to reproduce this error without torch.compile. I changed the code to run flash_attn_varlen_func twice, and the grad values differ between runs. Since no compilation is involved, can we assume this issue is independent of the changes introduced here?

Below are the changes I added to the test file:

```python
def copy_inputs(args, kwargs):
    new_args = [arg.clone().detach().requires_grad_(arg.requires_grad) if isinstance(arg, torch.Tensor) else arg for arg in args]
    new_kwargs = {k: v.clone().detach().requires_grad_(v.requires_grad) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return new_args, new_kwargs


def run_forward_backward(f, args, kwargs):
    args_, kwargs_ = copy_inputs(args, kwargs)
    out = f(*args_, **kwargs_)
    sm = out.sum().abs()
    diff_args = [arg for arg in args_ if isinstance(arg, torch.Tensor) and arg.requires_grad]
    out_grad = torch.autograd.grad(sm, diff_args, allow_unused=True)
    return out, out_grad


def run_opcheck(fn):
    def wrapper(*args, **kwargs):
        if should_run_schema_check(args, kwargs):
            safe_schema_check(fn, args, kwargs)

        if should_run_fake_check(args, kwargs):
            safe_fake_check(fn, args, kwargs)

        if should_test_backward(args, kwargs):
            # Here fn is flash_attn_func or flash_attn_varlen_func
            # and we run forward and backward twice to compare grads
            out, out_grad = run_forward_backward(fn, args, kwargs)
            out_2, out_grad_2 = run_forward_backward(fn, args, kwargs)
            torch.testing.assert_close(out, out_2)
            torch.testing.assert_close(out_grad, out_grad_2)
        return fn(*args, **kwargs)
    return wrapper
```

@v0i0
Collaborator

v0i0 commented Nov 20, 2025

> Hi @v0i0. I'm starting to think that this failure is not actually related to the changes made in this PR (support torch.compile).
>
> It is only reproducible when add_unused_qkv=True and attention_chunk=0 are set. I manually verified the dtype, shape, strides, layout, and device of the tensors generated under torch.compile, and they match those produced in eager mode. I'll spend some more time on it later this week.

I think it's expected that bwd is not bitwise identical between runs for FA. I still see errors of the form:

```
RuntimeError: flash_attn_3::_flash_attn_forward() Expected a value of type 'int' for argument 'window_size_left' but instead found type 'FunctionalTensor'.
```

e.g. in test_flash_attn.py::test_flash_attn_varlen_output[2048-2048-64-True-False-True-15.0-False-False-mha-dtype0]

@guilhermeleobas
Contributor Author

@v0i0, I couldn't reproduce the error you posted. I'm using a build of FA3 with all flags enabled.

```sh
$ FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK=FALSE FLASH_ATTENTION_ENABLE_OPCHECK=TRUE python -m pytest test_flash_attn.py --tb=short -k "test_flash_attn_varlen_output[2048-2048-64-True-False-True-15.0-False-False-mha-dtype0]" -rs
================================ test session starts ================================
platform linux -- Python 3.12.11, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/git/flash-attention/hopper
collected 98751 items / 98750 deselected / 1 selected

test_flash_attn.py .                                                          [100%]

======================= 1 passed, 98750 deselected in 10.76s ========================
```

Could you share the stack trace? And which flags are you building FA with?

@v0i0
Collaborator

v0i0 commented Nov 22, 2025

> @v0i0, I couldn't reproduce the error you posted. I'm using a build of FA3 with all flags enabled.
>
> Could you share the stack trace? And which flags are you building FA with?

i just tried with your latest changes & pytorch nightly, and i can't reproduce it. so i think we're good to go :-)

@v0i0 v0i0 self-requested a review November 22, 2025 00:05
@v0i0
Collaborator

v0i0 commented Nov 22, 2025

@tridao any objections to merging this?

@vadimkantorov

vadimkantorov commented Dec 26, 2025

@guilhermeleobas Does FAv3 now work under fullgraph too? And same question about FAv2:

I've also found an interesting case: for torch.func.grad to work on FA-enabled model, this fix was needed:

Maybe worth adding such a torch.func.grad test into FA codebase too... to check for correctness of torch.Library / registrations

@guilhermeleobas
Contributor Author

> @guilhermeleobas Does FAv3 now work under fullgraph too?

This PR is FA3-specific (the code under the /hopper folder), and I think it should work with fullgraph=True.

> Maybe worth adding such a torch.func.grad test into FA codebase too... to check for correctness of torch.Library / registrations

We do have some tests for that in test_flash_attn.py. The check is quite expensive, so it is only executed if the ENABLE_AUTOGRAD_CHECK env var is enabled.

```python
if should_test_backward(args, kwargs):
    # Expensive check
    safe_aot_autograd_check(fn, args, kwargs, dynamic=False)
    safe_aot_autograd_check(fn, args, kwargs, dynamic=True)
```

@janeyx99
Contributor

janeyx99 commented Jan 8, 2026

@guilhermeleobas It looks like after this change, test_flash3_bw_compatibility is breaking backward compatibility on this line: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L1243. In particular, the return types of the bwd operator have changed. cc @v0i0 @tridao on whether this is a concern or whether we should just accept the BC break and modify the schema in the test.

Test output:

test_flash_attn.py::test_flash3_bw_compatibility FAILED

========================== FAILURES ===========================
________________ test_flash3_bw_compatibility _________________

    def test_flash3_bw_compatibility() -> None:
        # Let's try to always stay backward compatible! This will make life easier
        # for downstream libaries, users, and exported models.
        # 1/ Instead of removing arguments, error out if their value is no longer supported
        # 2/ When adding arguments, add them at the end with a default value
        assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema(
            "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
            "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, "
            "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
            "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, "
            "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, "
            "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, "
            "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, "
            "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
            "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, "
            "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) "
            "-> (Tensor(out!), Tensor, Tensor, Tensor)"
        ))
>       assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema(
            "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
            "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, "
            "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, "
            "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, "
            "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) "
            "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"
        ))
E       AssertionError: assert False
E        +  where False = is_backward_compatible_with(flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor))
E        +    where is_backward_compatible_with = flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) -> (Tensor, Tensor, Tensor, Tensor, Tensor).is_backward_compatible_with
E        +      where flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) -> (Tensor, Tensor, Tensor, Tensor, Tensor) = <OpOverload(op='flash_attn_3.bwd', overload='default')>._schema
E        +        where <OpOverload(op='flash_attn_3.bwd', overload='default')> = <OpOverloadPacket(op='flash_attn_3.bwd')>.default
E        +          where <OpOverloadPacket(op='flash_attn_3.bwd')> = <module 'torch.ops.flash_attn_3' from 'torch.ops'>.bwd
E        +            where <module 'torch.ops.flash_attn_3' from 'torch.ops'> = <module 'torch.ops' from '_ops.py'>.flash_attn_3
E        +              where <module 'torch.ops' from '_ops.py'> = torch.ops
E        +    and   flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor) = parse_schema('flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)')

test_flash_attn.py:1243: AssertionError

@guilhermeleobas
Contributor Author

guilhermeleobas commented Jan 8, 2026

Hi @janeyx99, thanks for flagging this. This is indeed a backward-incompatible change: I had to update the mha_bwd signature to make it work with torch.compile. mha_bwd returns some of its inputs (dq, dk, dv), and this pattern causes issues with torch.compile. I can update the test, depending on the decision of @v0i0 and @tridao.

```python
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
softmax_d, *rest = flash_attn_3_cuda.bwd(
    dout,
```

I also believe this change should have minimal user impact, since _flash_attn_backward is not called directly by users and is instead invoked via FlashAttnQKVPackedFunc, FlashAttnFunc, or FlashAttnVarlenFunc.

Edit: Created #2153, but will wait for the decision of @v0i0 and @tridao on this.
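The aliasing issue described above can be illustrated without any CUDA at all (a plain-Python analogy, not the actual kernel code): an op that returns one of its input buffers produces an output whose identity is tied to the input, which is the kind of input/output aliasing that torch.compile's functionalization pass has trouble with; returning a copy restores value semantics at the cost of an extra allocation, which is what the clone() change trades away.

```python
def bwd_returning_input(dq):
    """Writes the result into dq and returns dq itself (aliases the input)."""
    dq[0] += 1.0
    return dq

def bwd_returning_copy(dq):
    """Same computation, but returns a fresh buffer (the clone() pattern)."""
    out = list(dq)
    out[0] += 1.0
    return out

buf = [0.0, 0.0]
aliased = bwd_returning_input(buf)   # output IS the input buffer
buf2 = [0.0, 0.0]
copied = bwd_returning_copy(buf2)    # output is independent of the input
```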

@varunneal

Is backward working for flash_attn_varlen_func? I'm getting issues:

```
[rank0]:           ^^^^^^^^^^^^^^^^^
[rank0]:   File "/tmp/torchinductor_root/cq/ccqgytym5tecofgnv4jkt7tn673rxfkennynvxwquy4in5znm3yw.py", line 8366, in call
[rank0]:     assert_size_stride(buf35, (16384, 6, 128), (768, 128, 1), 'torch.ops.flash_attn_3._flash_attn_backward.default')
[rank0]: AssertionError: wrong number of dimensions2 for op: torch.ops.flash_attn_3._flash_attn_backward.default
```

I'm on the main branch of this repo.

@janeyx99
Contributor

@varunneal did you also build FA3 on main?

@varunneal

@janeyx99 Yes, I built from the hopper directory on main via these commands:

```sh
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper

export MAX_JOBS=16
export FLASH_ATTENTION_FORCE_BUILD=TRUE
export FLASH_ATTENTION_DISABLE_SM80=TRUE
export FLASH_ATTENTION_DISABLE_FP16=TRUE

python setup.py bdist_wheel
```

@guilhermeleobas
Contributor Author

@varunneal do you have a reproducer?

@varunneal

@guilhermeleobas @janeyx99 I found the mistake, it was on my end. Sorry about that, thanks

elewarr pushed a commit to elewarr/flash-attention that referenced this pull request Feb 4, 2026
…s/fa3-compile

Add torch.compile support to flash attention 3
YangWang92 pushed a commit to YangWang92/flash-attention that referenced this pull request Feb 15, 2026
…s/fa3-compile

Add torch.compile support to flash attention 3
Strivin0311 pushed a commit to Strivin0311/flexible-flash-attention that referenced this pull request Feb 27, 2026
…s/fa3-compile

Add torch.compile support to flash attention 3