Add native gfx12 attention backend#368
Conversation
|
@0xDELUXA if possible can you try this PR branch on comfyui ? |
Sure! I’m AFK right now, but I’ll try it later. Looking great so far! |
|
I am using this implementation of SageAttention with the Z-Image Turbo model on a 9070XT GPU. Issue: |
Thanks for checking! does this happen with sageattention v1 too? or is it specific to this PR? |
|
@trfmk123 just pushed a fix. Can you try again? |
Just tested it, and the image outputs for Z-Image Turbo are completely normal now. |
Thanks! Perf has regressed a bit due to the fix, but getting the quality right is the main thing. I'll try to do a perf pass. |
I suggest implementing v_scale scaling within the kernel. This would be a great way to fix the desaturation/precision issues without sacrificing any performance. |
Good point. The gfx12 code currently takes raw fp8 weights without scaling which isn't desirable. |
|
I'm using ROCm |
Did you install the "rocm-sdk-devel" pip wheel? it's part of the |
I used this command to install: |
Can you try |
Actually, this one goes into an endless loop, so don't try that. Do |
This installs ROCm |
Yeah that should work. I use the same wheels. |
|
No errors when building now. But these wheels are close to a month old. Don’t we have any newer ones that are “stable”? |
|
They're in the index but not being picked up for some reason. These ones are from may 11 |
|
I don't really understand what you meant by: in #368 (comment). AFAIK, if xformers isn't installed as a standalone package, ComfyUI won't print that message. For me, regardless of the attention backend used (SDPA, Sage, Flash), it always prints Could this be happening on your end because you're also working on ROCm/xformers#87? |
|
I also have a large FLUX.2 FP8 multi-stage workflow, and the only change I make is swapping Flash attention for Sage attention. Across multiple runs, runtime improves from 240 s/it to 60 s/it. At first, I was skeptical that this was actually the case, but it turned out to be true. [ |
Hello everyone. There is a problem with the runners running in the CI when running smoke tests. Alredy reported and Scott took a look at it. Ive been using v2-staging for almost 3 weeks for per family releases. @jammm Ive been using multi arch releases to test them too. Will try to build and test against multi arch releases later today. Any specific script or test ? Or a general run is okay ? |
|
I'm running tests now with this version to see how it compares against SA v1 that I was using before, will report my findings as soon as I have results. Out of curiosity, is there a way to verify that the SA v2 path is actually used? |
|
Hi everyone, I've provided a simple speed test script here: Google Drive Link I tested this on an AMD 9070XT using the Multi-arch releases (untested packages). My environment details: PyTorch: 2.12.0+rocm7.13.0a20260514 ROCm: 7.13.26192
Hope this script helps! |
|
FYI this is a crazy speed boost on my r9700 pro. Reducing LTX and WAN render times down by considerable amounts. LTX 2.3 736x480 20 seconds WAN 2.2 736x480 5 seconds JuggernautXL down to 8s a render with 30 steps. Bravo to those who worked on this. I'm happy to share logs or any data anyone wants. I'm running on an r9700 pro. FYI I power restrict my card at 200w so many users will have better numbers. |
|
9070XT GPU After testing under Linux, it was found that the speed slowdown began with the update "6b71b93Support for native sequence tails in graphics 12". Subsequent updates did not show any further significant slowdown |
My findings: SageAttention v1 - 281.80 seconds - clip encode (x2) + diffusion (x2) + vae (x2) My workflow is a high-low pass setup that uses 2x SamplerCustomAdvanced, each with it's own Clip Text Encoder and it's own vae for preview/saving results. This was only a single generation (identical on both cases! Only variable that changed between the two was SageAttention v1 to v2) comparison, so not entirely conclusive yet. But the results are impressive if they hold up like this. Doing the math on this single comparison run there's a 35% improvement in generation performance running my full workflow. |
@trfmk123 Which FlashAttention version are you using? I assume it isn't the latest one from aiter. Are you using v2.8.3? |
According to the PR description, it prioritizes this implementation over Triton on RDNA4:
|
my setup is running Flash Attention 2.8.4 via the composable_kernel (ck) backend. |
|
Quick script I used to bench this native Sage 2.2.0 implementation against Flash 2.8.4 (aiter triton) on Windows. Results: (venv) PS C:\> python bench_sage2_gfx12_vs_fa2.py
Device : AMD Radeon RX 9060 XT (gfx1200)
Config : dtype=fp16 causal=False B=1 Hq=32 Hkv=32 warmup=50 iters=200
Loading implementations …
[sage] using sageattn_qk_int8_pv_gfx12_native (direct)
C:\ComfyUI\venv\Lib\site-packages\flash_attn\flash_attn_interface.py:17: UserWarning: flash_attn_2_cuda (which has ROCm/HIP kernels) not found, falling back to Triton implementation
warnings.warn("flash_attn_2_cuda (which has ROCm/HIP kernels) not found, falling back to Triton implementation")
[aiter] Windows: CK and HIP ops are not available. Triton ops only.
[fa2] using flash_attn.flash_attn_func
S D Sage ms FA2 ms Sage TF/s FA2 TF/s Speedup
────────────────────────────────────────────────────────────────────
512 64 0.240 0.320 8.96 6.71 1.335x
1024 64 0.391 0.567 21.95 15.15 1.449x
2048 64 0.710 1.284 48.37 26.75 1.808x
4096 64 2.110 4.218 65.14 32.58 1.999x
8192 64 7.314 16.832 75.16 32.66 2.301x
512 128 0.290 0.398 14.81 10.80 1.371x
1024 128 0.603 0.797 28.47 21.55 1.322x
2048 128 1.444 2.117 47.59 32.46 1.466x
4096 128 4.216 8.041 65.20 34.19 1.907x
8192 128 14.962 31.508 73.49 34.90 2.106x
[sage] S=1024 D=256 FAILED: gfx12 fp8 value path currently supports head_dim 16, 64, or 128.
1024 256 n/a 1.654 n/a 20.77 n/a
[sage] S=2048 D=256 FAILED: gfx12 fp8 value path currently supports head_dim 16, 64, or 128.
2048 256 n/a 4.876 n/a 28.19 n/a
[sage] S=4096 D=256 FAILED: gfx12 fp8 value path currently supports head_dim 16, 64, or 128.
4096 256 n/a 16.888 n/a 32.55 n/a
────────────────────────────────────────────────────────────────────
SageAttn gfx12 wins 10/10 configs | avg speedup vs FA2: 1.706x
→ SageAttention gfx12 native is faster on average.
── Numerical sanity check (S=1024, D=64) ──
max |sage - fa2| : 0.02026
mean|sage - fa2| : 0.001528
✓ outputs are numerically close (expected for INT8 quant)Quick chart:
Great results overall! 🚀 @jammm Is there a limitation preventing support for head_dim 256? |
Oh, I thought you're on Windows - never mind then. My point was that if someone doesn't have FA-2 CK, it automatically uses Triton, but the script still prints CK, which is a bit misleading. |
There is. I'm working on bringing parity vs the CUDA path. |
I see. Again, huge respect for the work you put into this. At first, we only had SDPA Math on RDNA4 (Windows). Oh, and it all started with your PyTorch wheels on Windows. 👀 |
|
I referred to this Japanese article to compile Flash Attention 2.8.4 (CK backend) on Windows: |
It's kind of all over the place with questions like "Can fa4-v4.0.0.beta be built?" The answer is no - it can't. We can build FA3, though. It also has FP8 support on RDNA4, for example. Also, that |
|
Finally able to compile Windows multi arch @jammm I had to remove the include dirs
I was hitting duplicate headers (sent you a Discord message about that). Not sure if related to torch, multiarch or what (fresh venv) Compared against FA2 CK backend. Similar results to @trfmk123
Using the Torchmark script
|
Windows ROCm Build - Image Corruption BugHi @jammm, I successfully built SageAttention on Windows with ROCm 7.13 + gfx1201 (RX 9070 XT), but the generated images come out as pure noise. Environment
Build Issues & WorkaroundsThe code couldn't be compiled as-is on Windows due to math function resolution issues. The following functions were not recognized by the HIP compiler on Windows:
Workaround used to get it to compile:
ProblemWith the above workaround, the build succeeds and SageAttention loads, but all generated images are corrupted (pure noise). Without I also tried replacing with This suggests the math function substitution is producing incorrect results in the GPU kernel, likely affecting the quantization scale calculation in QuestionIs there a correct way to resolve Speed improvement without the bug would be significant — the workflow went from 34s → 23s in testing with ROCm 7.2 triton fallback. Thanks for the great work on this PR! Update: I tried again from a clean checkout and found that the pure-noise output was likely caused by my previous workaround using A less aggressive workaround now builds and produces correct images on my setup. Additional workaround that worked for me:
After this, SageAttention imports and ComfyUI generates correct images. Current result:
Environment:
One more Windows-specific issue: So the current status is: the gfx12 native backend can work on Windows/gfx1201, but Windows HIP needs a few compile/import fixes. The previous pure-noise result should probably be treated as caused by my bad math-function workaround, not necessarily a kernel bug. |
|
@qweqweewqe7-create This looks like an environment-specific issue on your side rather than a universal problem. Also, |
|
@qweqweewqe7-create over the past days I have compiled it several times using nightlies of 20260416 which has a newer torch not currently available with newer nightly builds, and I had no issues with compiling nor running it. It might be the specific nightly you're trying to build against is (partially) broken? |








Summary
Adds a native ROCm gfx12 backend for SageAttention on RDNA4, including:
sageattnAPIBuild
Windows
setup.pydiscovers ROCm throughrocm-sdk, setsROCM_HOME, adds the ROCm LLVM/bin paths, and defaults the Windows compiler settings toclang-cl. Users still need to run from an initialized Visual Studio shell.Optional cross-build target:
Linux
pip install --no-build-isolation -v .Optional cross-build target:
PYTORCH_ROCM_ARCH=gfx1201 pip install --no-build-isolation -v .Correctness
Validated native gfx12 output against FlashAttention:
Additional runtime compatibility smoke:
torch.Size([2, 14040, 12, 128])torch.Size([2, 512, 12, 128])0.036942, std ratio0.999413, cosine0.999318versus v10.0086-0.0100, std ratio about0.99997-1.00006versus v1Performance
Measured on gfx1201 / Radeon RX 9070 XT, B=1, H=32, S=1K/2K/4K/8K. FlashAttention comparison uses the installed FlashAttention package in the ROCm venv.
ComfyUI Wan2.1 fp8 Workload
Tested Wan2.1 1.4B with fp8 model weights on a Radeon RX 9070 XT:
--use-pytorch-cross-attention: 2.53 s/it diffusion steps, 120s total--use-sage-attention: 1.78 s/it diffusion steps, 100s totalThe diffusion steps run about 42% faster with the gfx12 SageAttention v2 kernels. The native gfx12 path uses int8 WMMA for QK and fp8 WMMA for PV, with compiled ISA containing
v_wmma_i32_16x16x16_iu8andv_wmma_f32_16x16x16_fp8_fp8.In both runs, VAE decode is the current bottleneck because it exceeds the 16GB VRAM on the 9070 XT. Higher-VRAM RDNA4 cards should avoid this issue.
This was measured with MIOpen disabled, which is ComfyUI's current default behavior. With
COMFYUI_ENABLE_MIOPEN=1, the first run took 265s overall and the second run completed in 84.5s.Versus FlashAttention
Most large-shape fp8/fp16 cases are faster than FlashAttention. Remaining short-shape gaps are mainly fp16 D64 causal at 1K-4K and fp8 causal D128 at 1K.
Additional focused tail-shape check after the smooth-K quality fix:
Versus SageAttention v1 fp8
Latest fp8 native-vs-v1 sweep, B=1, H=32, S=1K/2K/4K/8K, D=16/64/128: