
Initial Windows ROCm build support for FlashAttention-2 ROCm/aiter Triton backend #2428

Closed
0xDELUXA wants to merge 1 commit into ROCm:main from 0xDELUXA:fa2-triton-win-support

Conversation

@0xDELUXA (Contributor) commented Mar 23, 2026

Motivation

Since AMD migrated FlashAttention-2's Triton backend to ROCm/aiter, Windows users with AMD GPUs have been unable to build FA-2 at all - the entire build pipeline assumes Linux. This PR adds the necessary platform support so FA-2 can be built and used on Windows via the ROCm/HIP SDK.

Note

This PR depends on Windows build support being merged in Dao-AILab/flash-attention first (or applied locally). See the corresponding PR: Dao-AILab/flash-attention#2384

Technical Details

cpp_extension.py

  • Add Windows ROCm/HIP SDK discovery via HIP_PATH, bundled venv SDK packages (_rocm_sdk_devel, _rocm_sdk_core), and standard install paths (C:\Program Files\AMD\ROCm\*)
  • Gate Linux-only compiler flags (-fPIC, -mcmodel=large, ELF linker flags) behind sys.platform != "win32" checks
  • Fix ninja build file path escaping for Windows drive letters (E:/foo must be written as E$:/foo, since ninja treats an unescaped colon as a rule separator)
  • Use .pyd/.dll extensions and correct Windows linker flags (/SUBSYSTEM:WINDOWS, /DLL) for Python extension modules
  • Add _quote() helper that uses double-quotes on Windows instead of shlex.quote() (single quotes break hipcc include paths)
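The SDK discovery order described above can be sketched roughly as follows. This is a simplified illustration, not the PR's actual code: `find_hip_home` and the exact candidate-path handling are assumptions based on the bullet points.

```python
import glob
import os
import sys


def find_hip_home():
    """Locate a ROCm/HIP SDK root on Windows (illustrative sketch).

    Order mirrors the description above: the HIP_PATH env var first,
    then bundled venv SDK packages, then standard install paths.
    """
    if sys.platform != "win32":
        # On Linux the usual ROCM_PATH / default prefix applies.
        return os.environ.get("ROCM_PATH", "/opt/rocm")
    hip_path = os.environ.get("HIP_PATH")
    if hip_path and os.path.isdir(hip_path):
        return hip_path
    # Bundled venv SDK packages (_rocm_sdk_devel / _rocm_sdk_core).
    for pkg in ("_rocm_sdk_devel", "_rocm_sdk_core"):
        try:
            mod = __import__(pkg)
            return os.path.dirname(mod.__file__)
        except ImportError:
            pass
    # Standard install locations, newest version first.
    for candidate in sorted(glob.glob(r"C:\Program Files\AMD\ROCm\*"),
                            reverse=True):
        if os.path.isfile(os.path.join(candidate, "bin", "hipcc.exe")):
            return candidate
    return None
```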

chip_info.py

  • Add hipinfo fallback alongside rocminfo for GPU/CU detection (rocminfo is not shipped with the Windows ROCm SDK)
  • Refactor arch and CU count parsing into shared _extract_gfx_from_output() / _extract_cu_from_output() helpers
  • Identify gfx942 SKUs by CU count on Windows instead of PCI chip ID (libamdhip64 ctypes probe is Linux-only)

core.py

  • Resolve --offload-arch=native on Windows (hipcc.exe does not support it; replace with detected gfx string)
  • Use temp files instead of /dev/null and shell pipes in hip_flag_checker() and check_LLVM_MAIN_REVISION()
  • Skip /proc/sys/kernel/numa_balancing check on Windows
  • Replace os.system("rm -rf ...") with shutil.rmtree / glob + os.remove
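Two of these portability fixes can be illustrated in isolation. Helper names here are hypothetical; the actual PR code may structure this differently:

```python
import glob
import os
import shutil


def resolve_offload_arch(flags, detected_gfx):
    """hipcc.exe rejects --offload-arch=native, so substitute the gfx
    string detected at configure time (sketch)."""
    return [
        f"--offload-arch={detected_gfx}" if f == "--offload-arch=native" else f
        for f in flags
    ]


def remove_tree(pattern):
    """Portable replacement for os.system('rm -rf <pattern>')."""
    for path in glob.glob(pattern):
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
        elif os.path.exists(path):
            os.remove(path)
```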

parallel_state.py

  • Guard torch.distributed import for Windows ROCm builds where it may be unavailable
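The guard boils down to a standard optional-import pattern; a generic sketch (not the PR's literal code) looks like this:

```python
import importlib


def optional_import(name):
    """Import a module if available, else return None, so callers can
    feature-gate on the result instead of crashing at import time."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Usage mirroring the parallel_state.py guard: distributed support is
# optional on Windows ROCm builds.
dist = optional_import("torch.distributed")
HAS_DIST = dist is not None and dist.is_available()
```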

setup.py

  • Skip Linux-specific hsa/ assembly blobs on Windows
  • Move pre-build kernel compilation inside NinjaBuildExtension.run() to defer heavy imports until build time

Test Plan

  • Source build and install (using both PRs for now):

```powershell
git clone -b fa2-aiter-triton-win-support https://github.com/0xDELUXA/flash-attention.git
cd flash-attention\third_party
git clone -b fa2-triton-win-support https://github.com/0xDELUXA/aiter.git
cd ..
$env:ENABLE_CK = "0"
$env:PREBUILD_KERNELS = "0"
$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE"
pip install --no-build-isolation -e .
```

  • Run basic tests via `from flash_attn import flash_attn_func`.

Test Result

  • Successfully built FlashAttention-2 with aiter on Windows with an AMD GPU (gfx1200), ROCm 7.13.0a20260321, PyTorch 2.12.0a0+rocm7.13.0a20260321, Python 3.12.
  • All tests passed.

Discovery

aiter.ops.triton.attention.mha.flash_attn_func works directly on Windows when a tuned config JSON is placed at aiter/ops/triton/configs/gfxNNNN-MHA-DEFAULT.json. Benchmarking against flash_attn.flash_attn_func on gfx1200 shows a 1.5x geomean speedup in favor of the aiter Triton kernel.

Used this benchmark script with this file placed in: aiter/ops/triton/configs/ (otherwise it won’t find the config file and will error out).

Results:

Importing aiter Triton backend... [aiter] import [module_aiter_enum] under C:\flash-attention\third_party\aiter\aiter\jit\module_aiter_enum.pyd
OK
Importing flash_attn... OK

PyTorch    : 2.12.0a0+rocm7.13.0a20260321
GPU        : AMD Radeon RX 9060 XT
aiter ver  : unknown  ->  aiter.ops.triton.attention.mha
flash_attn : 2.8.4  ->  flash_attn
dtype      : bf16  |  warmup=20  rep=100

[ 1/12] b1 hq32 hk32 sq1 sk4096 d128
         aiter  : 0.231 ms  0.07 TFLOPS
         fa     : 0.430 ms  0.04 TFLOPS
         winner : aiter  (1.863x)
[ 2/12] b1 hq32 hk32 sq512 sk4096 d128
         aiter  : 1.090 ms  7.88 TFLOPS
         fa     : 1.968 ms  4.37 TFLOPS
         winner : aiter  (1.806x)
[ 3/12] b1 hq32 hk32 sq2048 sk4096 d128
         aiter  : 3.895 ms  8.82 TFLOPS
         fa     : 6.978 ms  4.92 TFLOPS
         winner : aiter  (1.791x)
[ 4/12] b1 hq32 hk8 sq1 sk4096 d128
         aiter  : 0.202 ms  0.08 TFLOPS
         fa     : 0.196 ms  0.09 TFLOPS
         winner : flash_attn  (0.971x)
[ 5/12] b1 hq32 hk8 sq512 sk4096 d128
         aiter  : 0.988 ms  8.69 TFLOPS
         fa     : 1.432 ms  6.00 TFLOPS
         winner : aiter  (1.449x)
[ 6/12] b1 hq32 hk8 sq2048 sk4096 d128
         aiter  : 2.740 ms  12.54 TFLOPS
         fa     : 4.514 ms  7.61 TFLOPS
         winner : aiter  (1.647x)
[ 7/12] b4 hq32 hk32 sq1 sk4096 d128
         aiter  : 1.000 ms  0.07 TFLOPS
         fa     : 1.231 ms  0.06 TFLOPS
         winner : aiter  (1.231x)
[ 8/12] b4 hq32 hk32 sq512 sk4096 d128
         aiter  : 5.640 ms  6.09 TFLOPS
         fa     : 13.037 ms  2.64 TFLOPS
         winner : aiter  (2.312x)
[ 9/12] b4 hq32 hk32 sq2048 sk4096 d128
         aiter  : 15.426 ms  8.91 TFLOPS
         fa     : 41.548 ms  3.31 TFLOPS
         winner : aiter  (2.693x)
[10/12] b4 hq32 hk8 sq1 sk4096 d128
         aiter  : 0.590 ms  0.11 TFLOPS
         fa     : 0.459 ms  0.15 TFLOPS
         winner : flash_attn  (0.777x)
[11/12] b4 hq32 hk8 sq512 sk4096 d128
         aiter  : 3.506 ms  9.80 TFLOPS
         fa     : 5.835 ms  5.89 TFLOPS
         winner : aiter  (1.665x)
[12/12] b4 hq32 hk8 sq2048 sk4096 d128
         aiter  : 11.094 ms  12.39 TFLOPS
         fa     : 18.862 ms  7.29 TFLOPS
         winner : aiter  (1.700x)

===================================================================================================================
Config                                            aiter ms     fa ms     Δ ms   Speedup   aiter TFLOPS  fa TFLOPS
===================================================================================================================
b1 hq32 hk32 sq1 sk4096 d128                         0.231     0.430   -0.199    1.863x ✓           0.07       0.04
b1 hq32 hk32 sq512 sk4096 d128                       1.090     1.968   -0.878    1.806x ✓           7.88       4.37
b1 hq32 hk32 sq2048 sk4096 d128                      3.895     6.978   -3.083    1.791x ✓           8.82       4.92
b1 hq32 hk8 sq1 sk4096 d128                          0.202     0.196   +0.006    0.971x ✗           0.08       0.09
b1 hq32 hk8 sq512 sk4096 d128                        0.988     1.432   -0.444    1.449x ✓           8.69       6.00
b1 hq32 hk8 sq2048 sk4096 d128                       2.740     4.514   -1.774    1.647x ✓          12.54       7.61
b4 hq32 hk32 sq1 sk4096 d128                         1.000     1.231   -0.231    1.231x ✓           0.07       0.06
b4 hq32 hk32 sq512 sk4096 d128                       5.640    13.037   -7.397    2.312x ✓           6.09       2.64
b4 hq32 hk32 sq2048 sk4096 d128                     15.426    41.548  -26.122    2.693x ✓           8.91       3.31
b4 hq32 hk8 sq1 sk4096 d128                          0.590     0.459   +0.132    0.777x ✗           0.11       0.15
b4 hq32 hk8 sq512 sk4096 d128                        3.506     5.835   -2.330    1.665x ✓           9.80       5.89
b4 hq32 hk8 sq2048 sk4096 d128                      11.094    18.862   -7.768    1.700x ✓          12.39       7.29
-------------------------------------------------------------------------------------------------------------------
aiter wins / total                                      10 / 12
Geomean speedup (aiter / flash_attn)                 1.576x
===================================================================================================================

Comment on lines +31 to +34
```python
if sys.platform == "win32":
    if " " in s:
        return '"' + s.replace('"', '\\"') + '"'
    return s
```

There is an undocumented function in subprocess for this that will handle all cases. See my open pull request in PyTorch; I used it there.
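The reviewer doesn't name the function, but `subprocess.list2cmdline` fits the description: it is the undocumented helper CPython itself uses to join an argument list under the MS C runtime quoting rules. That this is the function the reviewer meant is my assumption.

```python
from subprocess import list2cmdline

# Arguments containing spaces get double-quoted; embedded quotes are
# escaped per the MS C runtime rules, covering the cases a hand-rolled
# _quote() can miss.
print(list2cmdline(["hipcc.exe", "-IC:\\Program Files\\AMD\\ROCm\\include"]))
# → hipcc.exe "-IC:\Program Files\AMD\ROCm\include"
```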

@0xDELUXA changed the title from "Initial FA-2 Triton Windows build support" to "Initial Windows ROCm build support for FlashAttention-2 ROCm/aiter Triton backend" on Mar 23, 2026
@0xDELUXA force-pushed the fa2-triton-win-support branch from 59de888 to 059ebbf on March 23, 2026 at 16:21
@0xDELUXA (Contributor, Author) commented:

Closing this in favor of the continuation PR by @micmelesse. Thanks for picking it up - I’m happy to help with anything if needed.
