
Conversation

@guilhermeleobas
Contributor

@guilhermeleobas commented Jul 25, 2025

I'm not 100% sure whether the torch.compile tests are representative of typical SageAttention usage. I wrote them based on the benchmark files.

Also, is there any lint rule that I should apply to the files?

Edit: I just saw there's a beta version of SageAttention 3 on HuggingFace. Is the code available? If so, I can also work on adding torch.compile support for it.

I'm not sure whether the tests I added are representative of the common usage of SageAttention.
@guilhermeleobas
Contributor Author

cc @StrongerXi

@StrongerXi

@jt-zhang @jason-huang03 would it be possible to merge this patch? It's pretty harmless: most of the code adds fake impls for the SageAttention ops (you can think of them as registering shape functions for torch.compile).
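(For readers unfamiliar with fake impls, here is a minimal sketch of the mechanism, using a hypothetical op name and signature rather than the actual SageAttention code. The fake impl only describes output metadata, which is all torch.compile needs in order to trace through the op.)

```python
import torch

# Hypothetical custom op; the real SageAttention ops and signatures differ.
@torch.library.custom_op("mylib::sage_attn_fwd", mutates_args=())
def sage_attn_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # The real op would launch the CUDA kernel here.
    raise NotImplementedError

@sage_attn_fwd.register_fake
def _(q, k, v):
    # Only shapes/dtypes/devices are produced; no kernel runs during tracing.
    return torch.empty_like(q)
```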

It actually fixes an accuracy issue that showed up with sage + compile: comfyanonymous/ComfyUI#8689 (comment).

Also, diffusers is adding SageAttention as a backend: huggingface/diffusers#12439, so merging this patch would benefit the many users who run sage + compile in diffusers and ComfyUI.

@jt-zhang
Member

jt-zhang commented Oct 6, 2025

Thank you for your PR. @whx1003, please help check and merge it.

@sayakpaul

Hey folks!

I am one of the maintainers of diffusers and can vouch for the merit of this PR, as it benefits torch.compile users quite a bit.

This would be great to have merged. Also cc: @MekkCyber

@whx1003
Collaborator

whx1003 commented Oct 7, 2025

@guilhermeleobas Thanks for the PR! The code looks good to me; I ran tests on an RTX 5090 and everything worked as expected.

Could you please keep only the changes under the sageattention/ directory in this PR? The other files seem unrelated.

@sayakpaul

@whx1003 tests/test_torch_compile.py could be beneficial for compilation tests. WDYT?

@whx1003
Collaborator

whx1003 commented Oct 7, 2025

@sayakpaul I agree that the tests could be useful, but we don’t maintain test files in this repo at the moment.

I’d prefer to leave them out of this PR and maybe revisit adding them later.

@sayakpaul

No worries! @guilhermeleobas, maybe we can keep the test file as a gist under your profile and link to it.

@guilhermeleobas
Contributor Author

Thanks for the feedback folks. @whx1003 I've removed the test file and unrelated changes.

@whx1003 merged commit 15c0e22 into thu-ml:main on Oct 8, 2025
@whx1003
Collaborator

whx1003 commented Oct 8, 2025

Thanks!

@woct0rdho

nitpick: If I understand correctly, in the fake impl:

```python
lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda")
```

it's better to set lse's device to query.device rather than "cuda" (the 0-th device), so it is consistent with the C++ code:

```cpp
lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32));
```
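In the Python fake impl, the suggested change would look roughly like this (a sketch reusing the variable names from the snippet above):

```python
# Allocate lse on the same device as the query tensor instead of the
# default "cuda" device, matching the C++ allocation above.
lse = torch.empty(
    (batch_size, num_qo_heads, qo_len),
    dtype=torch.float32,
    device=query.device,
)
```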

The same applies to the sm89 and sm90 fake impls. Maybe this also helps solve the issue behind the set_device workaround:
```python
# FIXME(DefTruth): make sage attention work compatible with distributed
# env, for example, xDiT which launch by torchrun. Without this workaround,
# sage attention will run into illegal memory access error after first
# inference step in distributed env for multi gpus inference. This small
# workaround also make sage attention work compatible with torch.compile
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)
```

but I haven't tested it.

@woct0rdho

Another noteworthy thing, although I haven't tested it, is that we may need to set mutates_args={"output"} when registering the ops.

If I understand correctly, this should definitely be needed, but I haven't yet found an example that gives correct output with it and wrong output without it.

In #74 (comment), he succeeded by specifying q, k, v as mutable parameters, but I don't think that makes sense...
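For reference, a hedged sketch of what declaring the mutated output could look like when registering a custom op (hypothetical op name and signature, not the actual SageAttention registration):

```python
import torch

# The op writes its result into a pre-allocated "output" tensor, so that
# argument is declared as mutated; q, k, v remain read-only inputs.
@torch.library.custom_op("mylib::sage_attn_fwd_out", mutates_args={"output"})
def sage_attn_fwd_out(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor
) -> None:
    # The real op would launch the CUDA kernel that fills `output`.
    raise NotImplementedError
```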
