Skip to content

Add native gfx12 attention backend#368

Open
jammm wants to merge 16 commits into
thu-ml:mainfrom
jammm:jam/gfx12
Open

Add native gfx12 attention backend#368
jammm wants to merge 16 commits into
thu-ml:mainfrom
jammm:jam/gfx12

Conversation

@jammm
Copy link
Copy Markdown

@jammm jammm commented May 13, 2026

Summary

Adds a native ROCm gfx12 backend for SageAttention on RDNA4, including:

  • gfx12 native QK int8 attention paths for fp8 and fp16 value modes
  • fp8 support for D16/D64/D128 and fp16 support for D16/D64
  • automatic gfx12 runtime dispatch from the public sageattn API
  • internal sequence padding plus logical KV tail masking so non-64 sequence lengths do not fall back
  • HIP build integration for Windows/Linux ROCm PyTorch wheels
  • HIP compatibility fixes for fused/smooth_k support

Build

Windows

setup.py discovers ROCm through rocm-sdk, sets ROCM_HOME, adds the ROCm LLVM/bin paths, and defaults the Windows compiler settings to clang-cl. Users still need to run from an initialized Visual Studio shell.

# Activate Visual Studio environment
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }

# Activate the virtual environment
.\venv\Scripts\Activate.ps1

pip install --no-build-isolation -v .

Optional cross-build target:

$env:PYTORCH_ROCM_ARCH = "gfx1201"
pip install --no-build-isolation -v .

Linux

pip install --no-build-isolation -v .

Optional cross-build target:

PYTORCH_ROCM_ARCH=gfx1201 pip install --no-build-isolation -v .

Correctness

Validated native gfx12 output against FlashAttention:

  • 50/50 cases passed
  • fp8: D16/D64/D128
  • fp16: D16/D64
  • causal and non-causal
  • S64/S128/S1024
  • HND and NHD layouts
  • GQA Hq/Hkv = 8/2
  • bf16 fp8 path
  • smooth_k fp8/fp16 path

Additional runtime compatibility smoke:

  • ComfyUI Wan2.1 NHD cross-attention runs through the gfx12 native path without fallback:
    • q: torch.Size([2, 14040, 12, 128])
    • k/v: torch.Size([2, 512, 12, 128])
  • The fp8 NHD D128 non-causal smooth path avoids materializing a padded Q copy for tail shapes while preserving CUDA-compatible smooth-K behavior.
  • Quality checks against SageAttention v1 on fp8/fp16 NHD cases preserve output scale:
    • Wan-style fp8 D128 tail: rel RMSE 0.036942, std ratio 0.999413, cosine 0.999318 versus v1
    • fp16 D64: rel RMSE about 0.0086-0.0100, std ratio about 0.99997-1.00006 versus v1

Performance

Measured on gfx1201 / Radeon RX 9070 XT, B=1, H=32, S=1K/2K/4K/8K. FlashAttention comparison uses the installed FlashAttention package in the ROCm venv.

ComfyUI Wan2.1 fp8 Workload

Tested Wan2.1 1.4B with fp8 model weights on a Radeon RX 9070 XT:

  • --use-pytorch-cross-attention: 2.53 s/it diffusion steps, 120s total
  • --use-sage-attention: 1.78 s/it diffusion steps, 100s total

The diffusion steps run about 42% faster with the gfx12 SageAttention v2 kernels. The native gfx12 path uses int8 WMMA for QK and fp8 WMMA for PV, with compiled ISA containing v_wmma_i32_16x16x16_iu8 and v_wmma_f32_16x16x16_fp8_fp8.

In both runs, VAE decode is the current bottleneck because it exceeds the 16GB VRAM on the 9070 XT. Higher-VRAM RDNA4 cards should avoid this issue.

This was measured with MIOpen disabled, which is ComfyUI's current default behavior. With COMFYUI_ENABLE_MIOPEN=1, the first run took 265s overall and the second run completed in 84.5s.

Versus FlashAttention

Mode Speedup vs FlashAttention
fp8 non-causal 1.20x-2.04x, geo 1.44x
fp8 causal 0.87x-1.73x, geo 1.27x
fp16 non-causal 0.94x-1.68x, geo 1.29x
fp16 causal 0.67x-1.45x, geo 1.03x

Most large-shape fp8/fp16 cases are faster than FlashAttention. Remaining short-shape gaps are mainly fp16 D64 causal at 1K-4K and fp8 causal D128 at 1K.

Additional focused tail-shape check after the smooth-K quality fix:

Shape Mode SageAttention gfx12 FlashAttention Speedup
B=2, Sq=14040, Sk=512, H=12, D=128 fp8 non-causal NHD 0.884 ms 1.017 ms 1.15x

Versus SageAttention v1 fp8

Latest fp8 native-vs-v1 sweep, B=1, H=32, S=1K/2K/4K/8K, D=16/64/128:

Mode Speedup vs SageAttention v1 fp8
fp8 non-causal 2.48x-9.38x, geo 4.46x
fp8 causal 2.47x-12.07x, geo 5.38x

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

@0xDELUXA if possible can you try this PR branch on comfyui ?

@0xDELUXA
Copy link
Copy Markdown

@0xDELUXA if possible can you try this PR branch on comfyui ?

Sure! I’m AFK right now, but I’ll try it later. Looking great so far!

@trfmk123
Copy link
Copy Markdown

I am using this implementation of SageAttention with the Z-Image Turbo model on a 9070XT GPU.

Issue:
There is a noticeable drop in precision. Specifically, the output images suffer from "color fading" — the saturation is much lower than expected, and the overall image looks pale.

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

I am using this implementation of SageAttention with the Z-Image Turbo model on a 9070XT GPU.

Issue: There is a noticeable drop in precision. Specifically, the output images suffer from "color fading" — the saturation is much lower than expected, and the overall image looks pale.

Thanks for checking! does this happen with sageattention v1 too? or is it specific to this PR?

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

@trfmk123 just pushed a fix. Can you try again?

@trfmk123
Copy link
Copy Markdown

trfmk123 commented May 14, 2026

@trfmk123 just pushed a fix. Can you try again?

Just tested it, and the image outputs for Z-Image Turbo are completely normal now.
Thanks a lot for the fix and your hard work!

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

@trfmk123 just pushed a fix. Can you try again?

Just tested it, and the image outputs for Z-Image Turbo are completely normal now. Thanks a lot for the fix and your hard work!

Thanks! Perf has regressed a bit due to the fix, but getting the quality right is the main thing. I'll try to do a perf pass.

@trfmk123
Copy link
Copy Markdown

@trfmk123 just pushed a fix. Can you try again?

Just tested it, and the image outputs for Z-Image Turbo are completely normal now. Thanks a lot for the fix and your hard work!

Thanks! Perf has regressed a bit due to the fix, but getting the quality right is the main thing. I'll try to do a perf pass.

I suggest implementing v_scale scaling within the kernel. This would be a great way to fix the desaturation/precision issues without sacrificing any performance.

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

@trfmk123 just pushed a fix. Can you try again?

Just tested it, and the image outputs for Z-Image Turbo are completely normal now. Thanks a lot for the fix and your hard work!

Thanks! Perf has regressed a bit due to the fix, but getting the quality right is the main thing. I'll try to do a perf pass.

I suggest implementing v_scale scaling within the kernel. This would be a great way to fix the desaturation/precision issues without sacrificing any performance.

Good point. The gfx12 code currently takes raw fp8 weights without scaling which isn't desirable.

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 14, 2026

I'm using ROCm 7.13.0a20260504, and rocm-sdk init didn't populate _rocm_sdk_devel\lib\llvm\lib\clang\ (might be a local issue), so I had to set HIPCC_APPEND_FLAGS to point to _rocm_sdk_core\lib\llvm\lib\clang\23\include, otherwise the build fails with fatal error: '__clang_hip_runtime_wrapper.h' file not found.

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

I'm using ROCm 7.13.0a20260504, and rocm-sdk init didn't populate _rocm_sdk_devel\lib\llvm\lib\clang\ (might be a local issue), so I had to set HIPCC_APPEND_FLAGS to point to _rocm_sdk_core\lib\llvm\lib\clang\23\include, otherwise the build fails with fatal error: '__clang_hip_runtime_wrapper.h' file not found.

Did you install the "rocm-sdk-devel" pip wheel? it's part of the "rocm[libraries,devel]" when you to a pip install.
If that didn't help, try again on powershell/cmd in admin mode

@0xDELUXA
Copy link
Copy Markdown

Did you install the "rocm-sdk-devel" pip wheel? it's part of the "rocm[libraries,devel]" when you to a pip install. If that didn't help, try again on powershell/cmd in admin mode

I used this command to install: python -m pip install --pre --index-url https://rocm.nightlies.amd.com/v2-staging/gfx120X-all/ torch torchvision torchaudio rocm[devel,libraries]

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

Did you install the "rocm-sdk-devel" pip wheel? it's part of the "rocm[libraries,devel]" when you to a pip install. If that didn't help, try again on powershell/cmd in admin mode

I used this command to install: python -m pip install --pre --index-url https://rocm.nightlies.amd.com/v2-staging/gfx120X-all/ torch torchvision torchaudio rocm[devel,libraries]

Can you try pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio rocm[libraries,devel] in a fresh venv?

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

Did you install the "rocm-sdk-devel" pip wheel? it's part of the "rocm[libraries,devel]" when you to a pip install. If that didn't help, try again on powershell/cmd in admin mode

I used this command to install: python -m pip install --pre --index-url https://rocm.nightlies.amd.com/v2-staging/gfx120X-all/ torch torchvision torchaudio rocm[devel,libraries]

Can you try pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio rocm[libraries,devel] in a fresh venv?

Actually, this one goes into an endless loop, so don't try that. Do pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ rocm[libraries,devel] first, then pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 14, 2026

Can you try pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio rocm[libraries,devel] in a fresh venv?

This installs ROCm 7.13.0a20260416, and after running rocm-sdk init, the _rocm_sdk_devel/lib/llvm/lib/clang/23/include folder is actually created.

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

Can you try pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ torch torchvision torchaudio rocm[libraries,devel] in a fresh venv?

This installs ROCm 7.13.0a20260416, and after running rocm-sdk init, the _rocm_sdk_devel/lib/llvm/lib/clang/23/include folder is actually created.

Yeah that should work. I use the same wheels.

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 14, 2026

No errors when building now. But these wheels are close to a month old. Don’t we have any newer ones that are “stable”?

@jammm
Copy link
Copy Markdown
Author

jammm commented May 14, 2026

They're in the index but not being picked up for some reason. These ones are from may 11
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ "rocm[libraries,devel]==7.13.0a20260511" "torch==2.10.0+rocm7.13.0a20260511" "torchaudio==2.10.0+rocm7.13.0a20260511" "torchvision==0.25.0+rocm7.13.0a20260511"

@0xDELUXA
Copy link
Copy Markdown

I don't really understand what you meant by:

With `--use-sage-attention`, ComfyUI reported `Using xformers attention in VAE`.

in #368 (comment).

AFAIK, if xformers isn't installed as a standalone package, ComfyUI won't print that message. For me, regardless of the attention backend used (SDPA, Sage, Flash), it always prints Using split attention in VAE unless I install xformers separately. This is partly why I created this node.

Could this be happening on your end because you're also working on ROCm/xformers#87?

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 14, 2026

I also have a large FLUX.2 FP8 multi-stage workflow, and the only change I make is swapping Flash attention for Sage attention. Across multiple runs, runtime improves from 240 s/it to 60 s/it. At first, I was skeptical that this was actually the case, but it turned out to be true. [gfx1200 + torch 2.13.0a0+rocm7.13.0a20260416 here.]

@crashingalexsan
Copy link
Copy Markdown

My issue is that: pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ "rocm[libraries,devel]==7.13.0a20260511" "torch==2.10.0+rocm7.13.0a20260511" "torchaudio==2.10.0+rocm7.13.0a20260511" "torchvision==0.25.0+rocm7.13.0a20260511" doesn't seem to include this fix, and in certain FP8 workflows it causes a comfy-kitchen "Fatal access violation" error. When I install 7.13.0a20260416 from v2-staging (with torch 2.13), the error disappears. I assume it's only included in 2.11 and later.

Hello everyone. There is a problem with the runners running in the CI when running smoke tests. Alredy reported and Scott took a look at it. Ive been using v2-staging for almost 3 weeks for per family releases.

@jammm Ive been using multi arch releases to test them too. Will try to build and test against multi arch releases later today. Any specific script or test ? Or a general run is okay ?

@DrywFiltiarn
Copy link
Copy Markdown

I'm running tests now with this version to see how it compares against SA v1 that I was using before, will report my findings as soon as I have results.

Out of curiosity, is there a way to verify that the SA v2 path is actually used?

@trfmk123
Copy link
Copy Markdown

Hi everyone,

I've provided a simple speed test script here: Google Drive Link

I tested this on an AMD 9070XT using the Multi-arch releases (untested packages). My environment details:

PyTorch: 2.12.0+rocm7.13.0a20260514

ROCm: 7.13.26192

image

Hope this script helps!

@crosson
Copy link
Copy Markdown

crosson commented May 15, 2026

FYI this is a crazy speed boost on my r9700 pro. Reducing LTX and WAN render times down by considerable amounts.

LTX 2.3 736x480 20 seconds
429s No Sage Attention
307s First Render
254s Second Render

WAN 2.2 736x480 5 seconds
307s No Sage Attention
232s First Render
163s Second Render

JuggernautXL down to 8s a render with 30 steps.

Bravo to those who worked on this. I'm happy to share logs or any data anyone wants. I'm running on an r9700 pro.

FYI I power restrict my card at 200w so many users will have better numbers.

@ouco1986
Copy link
Copy Markdown

ouco1986 commented May 15, 2026

9070XT GPU After testing under Linux, it was found that the speed slowdown began with the update "6b71b93Support for native sequence tails in graphics 12". Subsequent updates did not show any further significant slowdown

@DrywFiltiarn
Copy link
Copy Markdown

DrywFiltiarn commented May 15, 2026

I'm running tests now with this version to see how it compares against SA v1 that I was using before, will report my findings as soon as I have results.

Out of curiosity, is there a way to verify that the SA v2 path is actually used?

My findings:
Model: Flux 2 Klein Base 9B
Resolution: 1536x864

SageAttention v1 - 281.80 seconds - clip encode (x2) + diffusion (x2) + vae (x2)
SageAttention v2 - 182.59 seconds - clip encode (x2) + diffusion (x2) + vae (x2)

My workflow is a high-low pass setup that uses 2x SamplerCustomAdvanced, each with it's own Clip Text Encoder and it's own vae for preview/saving results.

This was only a single generation (identical on both cases! Only variable that changed between the two was SageAttention v1 to v2) comparison, so not entirely conclusive yet. But the results are impressive if they hold up like this. Doing the math on this single comparison run there's a 35% improvement in generation performance running my full workflow.

@0xDELUXA
Copy link
Copy Markdown

Hi everyone,

I've provided a simple speed test script here: Google Drive Link

I tested this on an AMD 9070XT using the Multi-arch releases (untested packages). My environment details:

PyTorch: 2.12.0+rocm7.13.0a20260514

ROCm: 7.13.26192

image Hope this script helps!

@trfmk123 Which FlashAttention version are you using? I assume it isn't the latest one from aiter. Are you using v2.8.3?
Also, the script shouldn't print "FlashAttention-2 CK" there at all.

@0xDELUXA
Copy link
Copy Markdown

Out of curiosity, is there a way to verify that the SA v2 path is actually used?

According to the PR description, it prioritizes this implementation over Triton on RDNA4:

  • automatic gfx12 runtime dispatch from the public sageattn API

@trfmk123
Copy link
Copy Markdown

@trfmk123 Which FlashAttention version are you using? I assume it isn't the latest one from aiter. Are you using v2.8.3? Also, the script shouldn't print "FlashAttention-2 CK" there at all.

my setup is running Flash Attention 2.8.4 via the composable_kernel (ck) backend.

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 15, 2026

Quick script I used to bench this native Sage 2.2.0 implementation against Flash 2.8.4 (aiter triton) on Windows.

Results:

(venv) PS C:\> python bench_sage2_gfx12_vs_fa2.py

Device : AMD Radeon RX 9060 XT  (gfx1200)
Config : dtype=fp16  causal=False  B=1  Hq=32  Hkv=32  warmup=50  iters=200

Loading implementations …
  [sage] using sageattn_qk_int8_pv_gfx12_native (direct)
C:\ComfyUI\venv\Lib\site-packages\flash_attn\flash_attn_interface.py:17: UserWarning: flash_attn_2_cuda (which has ROCm/HIP kernels) not found, falling back to Triton implementation
  warnings.warn("flash_attn_2_cuda (which has ROCm/HIP kernels) not found, falling back to Triton implementation")
[aiter] Windows: CK and HIP ops are not available. Triton ops only.
  [fa2]  using flash_attn.flash_attn_func

     S     D    Sage ms     FA2 ms   Sage TF/s    FA2 TF/s   Speedup
────────────────────────────────────────────────────────────────────
   512    64      0.240      0.320        8.96        6.71    1.335x
  1024    64      0.391      0.567       21.95       15.15    1.449x
  2048    64      0.710      1.284       48.37       26.75    1.808x
  4096    64      2.110      4.218       65.14       32.58    1.999x
  8192    64      7.314     16.832       75.16       32.66    2.301x
   512   128      0.290      0.398       14.81       10.80    1.371x
  1024   128      0.603      0.797       28.47       21.55    1.322x
  2048   128      1.444      2.117       47.59       32.46    1.466x
  4096   128      4.216      8.041       65.20       34.19    1.907x
  8192   128     14.962     31.508       73.49       34.90    2.106x
  [sage] S=1024 D=256 FAILED: gfx12 fp8 value path currently supports head_dim 16, 64, or 128.
  1024   256      n/a      1.654      n/a       20.77      n/a
  [sage] S=2048 D=256 FAILED: gfx12 fp8 value path currently supports head_dim 16, 64, or 128.
  2048   256      n/a      4.876      n/a       28.19      n/a
  [sage] S=4096 D=256 FAILED: gfx12 fp8 value path currently supports head_dim 16, 64, or 128.
  4096   256      n/a     16.888      n/a       32.55      n/a
────────────────────────────────────────────────────────────────────

SageAttn gfx12 wins 10/10 configs  |  avg speedup vs FA2: 1.706x
→ SageAttention gfx12 native is faster on average.

── Numerical sanity check (S=1024, D=64) ──
  max |sage - fa2| : 0.02026
  mean|sage - fa2| : 0.001528
  ✓ outputs are numerically close (expected for INT8 quant)

Quick chart:

Screenshot 2026-05-15 174442

Great results overall! 🚀

@jammm Is there a limitation preventing support for head_dim 256?

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 15, 2026

my setup is running Flash Attention 2.8.4 via the composable_kernel (ck) backend.

Oh, I thought you're on Windows - never mind then. My point was that if someone doesn't have FA-2 CK, it automatically uses Triton, but the script still prints CK, which is a bit misleading.

@jammm
Copy link
Copy Markdown
Author

jammm commented May 15, 2026

@jammm Is there a limitation preventing support for head_dim 256?

There is. I'm working on bringing parity vs the CUDA path.
Having said that, the CUDA path also rejects head_dim > 128, so this is not specific to gfx12. It's the blackwell specific code (sageattention v3) which supports 256. The scope of this PR is sageattention v2 though, since sageattention v3 needs fp4 support, which RDNA4 doesn't have.

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 15, 2026

@jammm Is there a limitation preventing support for head_dim 256?

There is. I'm working on bringing parity vs the CUDA path.
Having said that, the CUDA path also rejects head_dim > 128, so this is not specific to gfx12. It's the blackwell specific code (sageattention v3) which supports 256. The scope of this PR is sageattention v2 though, since sageattention v3 needs fp4 support, which RDNA4 doesn't have.

I see.

Again, huge respect for the work you put into this.

At first, we only had SDPA Math on RDNA4 (Windows).
After that, you helped make AOTriton a thing.
Later, your contributions to triton-windows made Sage1 and FA2 possible to use.
And now here we are with Sage2.

Oh, and it all started with your PyTorch wheels on Windows. 👀

@trfmk123
Copy link
Copy Markdown

trfmk123 commented May 15, 2026

@0xDELUXA

Oh, I thought you're on Windows - never mind then. My point was that if someone doesn't have FA-2 CK, it automatically uses Triton, but the script still prints CK, which is a bit misleading.

My OS is Windows 11. Below are my test results.
image

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 15, 2026

My OS is Windows 11. Below are my test results. image

So you’ve built FA2 CK on Windows? Not many people have done this...

Based on your results, FA2 CK outperforms Sage2 at smaller batch sizes (which wasn't the case for FA2 Triton), while Sage2 takes the lead at larger ones.

@trfmk123
Copy link
Copy Markdown

@0xDELUXA

So you’ve built FA2 CK on Windows? Not many people have done this...

Based on your results, FA2 CK outperforms Sage2 at smaller batch sizes (which wasn't the case for FA2 Triton), while Sage2 takes the lead at larger ones.

I referred to this Japanese article to compile Flash Attention 2.8.4 (CK backend) on Windows:
https://note.com/lpp/n/nf24da8645c3c

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 15, 2026

I referred to this Japanese article to compile Flash Attention 2.8.4 (CK backend) on Windows: https://note.com/lpp/n/nf24da8645c3c

It's kind of all over the place with questions like "Can fa4-v4.0.0.beta be built?" The answer is no - it can't. We can build FA3, though. It also has FP8 support on RDNA4, for example. Also, that setup.py patch isn't really needed after Dao-AILab/flash-attention#2517.

@crashingalexsan
Copy link
Copy Markdown

Finally able to compile Windows multi arch

@jammm I had to remove the include dirs

image

I was hitting duplicate headers (sent you a Discord message about that). Not sure if related to torch, multiarch or what (fresh venv)


Compared against FA2 CK backend. Similar results to @trfmk123

image

Using the Torchmark script

image

@woct0rdho
Copy link
Copy Markdown

Is there a limitation preventing support for head_dim 256?

@0xDELUXA FYI, head dim 256 is impossible on Nvidia sm90 (without some splitting) due to constraint of WGMMA size, and it's possible on Nvidia consumer architectures. I've tried it in #329 . I haven't tried how to do it on AMD GPU.

@qweqweewqe7-create
Copy link
Copy Markdown

qweqweewqe7-create commented May 17, 2026

Windows ROCm Build - Image Corruption Bug

Hi @jammm, I successfully built SageAttention on Windows with ROCm 7.13 + gfx1201 (RX 9070 XT), but the generated images come out as pure noise.

Environment

  • OS: Windows 11
  • GPU: AMD RX 9070 XT (gfx1201)
  • ROCm: 7.13.0a20260511
  • PyTorch: 2.10.0+rocm7.13
  • Python: 3.12

Build Issues & Workarounds

The code couldn't be compiled as-is on Windows due to math function resolution issues. The following functions were not recognized by the HIP compiler on Windows:

  • fmaxf, fabsf, fminf, nearbyintfno matching function for call
  • max in reduction_utils_hip.cuhuse of undeclared identifier

Workaround used to get it to compile:

  • Replaced fmaxf(__builtin_fmaxf( etc. in .cu source files
  • Replaced val = max(...) with ternary operator in reduction_utils_hip.cuh

Problem

With the above workaround, the build succeeds and SageAttention loads, but all generated images are corrupted (pure noise). Without --use-sage-attention, images generate correctly.

I also tried replacing with ::fmaxf(, __ocml_fmax_f32( etc. but the same corruption occurs.

This suggests the math function substitution is producing incorrect results in the GPU kernel, likely affecting the quantization scale calculation in qk_int_sv_gfx12_native.

Question

Is there a correct way to resolve fmaxf/fabsf on Windows HIP compiler? Or is there a known fix for this?

Speed improvement without the bug would be significant — the workflow went from 34s → 23s in testing with ROCm 7.2 triton fallback.

Thanks for the great work on this PR!

Update:

I tried again from a clean checkout and found that the pure-noise output was likely caused by my previous workaround using __builtin_fmaxf, __builtin_fabsf, etc.

A less aggressive workaround now builds and produces correct images on my setup.

Additional workaround that worked for me:

  • Removed explicit include_dirs=include_dirs from the ROCm CUDAExtension entries in setup.py
  • Added simple device helper functions for fabsf, fmaxf, fminf instead of replacing them with __builtin_*
  • Reused the existing v_cvt_i32_f32 inline asm rounding path instead of nearbyintf
  • Added local max / min helpers in reduction_utils.cuh
  • Replaced a few host-side std::max(q_heads, kv_heads) / ::max(...) grid dimension expressions with explicit ternaries, because hipify/Windows HIP produced ::max errors

After this, SageAttention imports and ComfyUI generates correct images.

Current result:

  • Baseline ComfyUI: ~34s
  • --use-sage-attention: ~28s
  • About 17.6% lower wall time / ~1.21x speedup

Environment:

  • Windows 11
  • RX 9070 XT / gfx1201
  • Python 3.12
  • PyTorch 2.10.0+rocm7.13.0a20260511
  • ROCm HIP 7.13.26176

One more Windows-specific issue: triton is not available for this environment via the AMD gfx120X wheel index, so importing sageattention failed at first because core.py imports the Triton backend unconditionally. I patched core.py locally to allow import without Triton and use only the gfx12 native backend.

So the current status is: the gfx12 native backend can work on Windows/gfx1201, but Windows HIP needs a few compile/import fixes. The previous pure-noise result should probably be treated as caused by my bad math-function workaround, not necessarily a kernel bug.

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 17, 2026

@qweqweewqe7-create This looks like an environment-specific issue on your side rather than a universal problem. Also, triton-windows needs to be installed separately in the venv - not specifically for this PR, but in general.

@DrywFiltiarn
Copy link
Copy Markdown

@qweqweewqe7-create over the past days I have compiled it several times using nightlies of 20260416 which has a newer torch not currently available with newer nightly builds, and I had no issues with compiling nor running it. It might be the specific nightly you're trying to build against is (partially) broken?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants