[FLYDSL] Add gfx1201 (RDNA4) flash_attn_func backend #2969
Conversation
Adds aiter.ops.flydsl.flydsl_flash_attn_func, a self-attention kernel
optimized for gfx1201 (Radeon AI PRO R9700 / R9600D / RX 9070 XT). The
kernel uses WMMA 16x16x16 on wave32, BLOCK_N=32, rocdl.exp2 intrinsic,
software-pipelined GEMM2, and overlapped V global load.
The Python wrapper:
- Caches built kernels by (num_heads, head_dim, causal, dtype).
- Auto-pads seq_len up to a multiple of BLOCK_M=128 with zero-padded
K/V/Q; the padded query rows are sliced off before returning, so
callers never see the alignment requirement.
- Restricts to self-attention (Lq == Lk); cross-attention raises.
- Validates head_dim >= 64 and head_dim % 32 == 0.
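The alignment and validation rules above can be sketched in plain Python (function names are illustrative, not the actual aiter implementation):

```python
BLOCK_M = 128  # kernel tile height; padded seq_len must be a multiple of this


def padded_seq_len(seq_len: int, block_m: int = BLOCK_M) -> int:
    """Round seq_len up to the next multiple of block_m."""
    return ((seq_len + block_m - 1) // block_m) * block_m


def validate_head_dim(head_dim: int) -> None:
    """Enforce the wrapper's head_dim constraints (>= 64, multiple of 32)."""
    if head_dim < 64 or head_dim % 32 != 0:
        raise ValueError(
            f"head_dim={head_dim}: must be >= 64 and a multiple of 32"
        )
```

For Wan2.1's S=32760, `padded_seq_len` returns 32768, matching the padded shape reported in the production validation.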
Production validation on Wan2.1-T2V-1.3B (480x832, 81 frames, 30 steps;
self-attn shape S=32760 padded to 32768, B=1, H=12, D=128, bf16):
Backend                              e2e wall time      kernel TFLOPS @ S=32768
-----------------------------------  -----------------  -----------------------
SDPA fallback                        617.1 s / 10.29m   29.0
Triton FA (BM=256/BN=32, gfx1201)    471.0 s / 7.85m    ~38
flydsl_flash_attn                    460.9 s / 7.68m    49.4
vs SDPA: 1.34x e2e speedup (25.3% faster); 1.70x at kernel level.
vs Triton FA: 2.1% e2e improvement.
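As a sanity check, the reported ratios follow directly from the wall-clock and TFLOPS numbers above (pure-Python arithmetic, no assumptions beyond the table):

```python
# e2e wall time in seconds, from the table above
sdpa_s, triton_s, flydsl_s = 617.1, 471.0, 460.9

e2e_speedup = sdpa_s / flydsl_s                      # vs SDPA, e2e
pct_faster = (sdpa_s - flydsl_s) / sdpa_s * 100      # percent faster than SDPA
kernel_speedup = 49.4 / 29.0                         # TFLOPS ratio at kernel level
vs_triton = (triton_s - flydsl_s) / triton_s * 100   # e2e improvement vs Triton FA
```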
Tests in op_tests/flydsl_tests/test_flydsl_fmha.py cover correctness vs
SDPA on aligned (B=1 S=32768 H=12 D=128 and B=2 S=1024 H=8 D=128) and
unaligned (B=1 S=32760 H=12 D=128, exercises the padding path) shapes,
plus argument validation (cross-attn rejection, head_dim, dtype).
The kernel itself is upstream FlyDSL's gfx1201 best variant (BLOCK_N=32
+ rocdl.exp2 + pipelined GEMM2 + overlapped V load); the Python file is
copied into aiter/ops/flydsl/kernels/ following the existing convention
for FlyDSL-backed kernels in this directory. dtype_to_elem_type is added
to kernels_common for symmetry with the other FlyDSL kernel families.
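A minimal sketch of what a `dtype_to_elem_type`-style helper does, using a plain string mapping for illustration (the real helper in kernels_common returns FlyDSL/MLIR type objects, not strings):

```python
# Hypothetical stand-in: maps torch dtype names to MLIR scalar element type
# names. The actual helper works with real ir.Type constructors from the
# FlyDSL runtime; this only illustrates the lookup-and-reject shape.
_DTYPE_TO_ELEM = {
    "torch.bfloat16": "bf16",
    "torch.float16": "f16",
    "torch.float32": "f32",
}


def dtype_to_elem_type(dtype_name: str) -> str:
    try:
        return _DTYPE_TO_ELEM[dtype_name]
    except KeyError:
        raise ValueError(
            f"unsupported dtype for FlyDSL kernels: {dtype_name}"
        ) from None
```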
Pull request overview
Adds a new FlyDSL-backed FlashAttention (FMHA) implementation targeting gfx1201 / RDNA4, exposing it as a public API and adding test coverage for correctness and input validation.
Changes:
- Introduces the `aiter.ops.flydsl.flydsl_flash_attn_func` wrapper with kernel build caching and seq_len auto-padding.
- Adds the gfx1201 FlyDSL kernel module and a shared `dtype_to_elem_type` helper.
- Registers the new API in `aiter.ops.flydsl` and adds a gfx1201-gated pytest module.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| op_tests/flydsl_tests/test_flydsl_fmha.py | New tests for FMHA correctness vs SDPA and argument validation (gfx1201 + FlyDSL gated). |
| aiter/ops/flydsl/kernels/kernels_common.py | Adds `dtype_to_elem_type()` utility for FlyDSL MLIR scalar type selection. |
| aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py | Adds upstream FlyDSL gfx1201 FlashAttention kernel implementation. |
| aiter/ops/flydsl/fmha_kernels.py | New public wrapper providing caching, padding, validation, and BSHD interface. |
| aiter/ops/flydsl/__init__.py | Exposes `flydsl_flash_attn_func` when FlyDSL is available and version-compatible. |
@sunway513 Large quantities of raw MLIR types are used in this kernel, including arith, scf, memref, and vector. Though these are not forbidden, they make the code very difficult to read. I suggest refactoring a bit. You can follow ROCm/FlyDSL#462 and this skill https://github.com/ROCm/FlyDSL/blob/main/.claude/skills/flydsl-internal-types-cleanup/SKILL.md to do it. If done as expected, it will not affect perf.
Addresses 5 of 6 Copilot inline review comments on PR ROCm#2969 (the 6th — padded softmax bug in the kernel — is handled separately):

#1 (multi-GPU device context): wrap the _get_kernel() build and exe() launch in `with torch.cuda.device(q.device.index):` so callers whose current device differs from q.device get the kernel built and launched on the correct device/stream.
#3 (docstring mismatch): the module docstring's cache-key list now includes waves_per_eu and daz to match the _get_kernel signature.
#4 (unused import): drop the unused `flydsl =` assignment from pytest.importorskip in the test file (silences ruff F841).
#5 (dtype/causal coverage): add two correctness tests —
- test_flydsl_fmha_correctness_f16 (Wan2.1-shape, fp16, non-causal)
- test_flydsl_fmha_correctness_causal_small (small bf16, causal)
#6 (device validation): at the top of flydsl_flash_attn_func, raise ValueError when q/k/v are not on the same device, and when the active device's gcnArchName is not gfx1201.

Bonus tests:
- test_flydsl_fmha_correctness_multi_device: runs the kernel on cuda:1 while the current device is cuda:0 in a subprocess, exercising the device-context wrap (#1). Xfails cleanly if the FlyDSL runtime pins to device 0 internally.
- test_flydsl_fmha_rejects_device_mismatch: validates the same-device guard from #6 directly (q on cuda:0, k/v on cuda:1).

Verified with black, ruff, and pytest (10/10 PASS) on R9600D gfx1201.
Wrapper-level safety guard for the padded-softmax bug raised by Copilot inline comment #2 on PR ROCm#2969.

Padded K/V tokens produce QK^T = 0, but exp(0) = 1 still contributes to the softmax denominator and silently scales the output for non-causal attention. Causal mode masks padded positions, so it is unaffected.

Empirical RCA at aiter-forge-baselines/2969_padded_softmax_rca.md:
- Wan2.1 production (S_real=32760, S_pad=32768, ratio=0.024%): cos_min 0.999992, max_abs 0.0008 — safe, indistinguishable from the bf16 noise floor.
- 50% padding worst case: rel_err 37.3%, max_abs 0.281 — silent output scaling that would corrupt downstream.

Implements option (d) from the RCA decision doc (signed off by Peng): a hybrid threshold. Non-causal calls with n_pad/seq_len_pad > 0.005 are rejected with a ValueError that points the caller at the three valid remediations (causal=True, pre-pad to a multiple of 128, or use a masking-aware kernel).

Threshold rationale: 0.5% is the bf16 mantissa precision floor (~0.4%, 7 mantissa bits) plus 1 bit of margin. Production Wan2.1 (0.024%) clears it by 20x, so the hot path stays open while the silent-disaster worst case is closed.

Tests added (op_tests/flydsl_tests/test_flydsl_fmha.py):
- test_flydsl_fmha_rejects_excessive_padding: B=1, S_real=129 (S_pad=256, 49.6% pad), causal=False — must raise ValueError containing the "0.5% safety threshold" substring.
- test_flydsl_fmha_allows_tight_padding: Wan2.1 case S_real=32760, causal=False — must succeed and match the SDPA reference (cos_min >= 0.9999). Regression guard for the production hot path.

Validation on R9600D (gfx1201) inside the wan-best container, HIP_VISIBLE_DEVICES=4: 10 passed, 2 skipped (multi-GPU only). black --check + ruff check both clean on the touched files.

Kernel file aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py is intentionally untouched — the refactor is in a parallel branch.
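The hybrid-threshold guard described above amounts to a few lines; this is an illustrative sketch (function and constant names are assumptions, not the actual wrapper code):

```python
PAD_SAFETY_THRESHOLD = 0.005  # 0.5%: bf16 mantissa floor (~0.4%) + 1 bit margin


def check_padding_safety(seq_len_real: int, seq_len_pad: int, causal: bool) -> None:
    """Reject non-causal calls whose padding ratio could silently scale output."""
    n_pad = seq_len_pad - seq_len_real
    if causal or n_pad == 0:
        return  # causal masks padded positions; no padding means no bias
    if n_pad / seq_len_pad > PAD_SAFETY_THRESHOLD:
        raise ValueError(
            f"padding ratio {n_pad / seq_len_pad:.3%} exceeds the 0.5% safety "
            "threshold; use causal=True, pre-pad to a multiple of 128, or a "
            "masking-aware kernel"
        )
```

The Wan2.1 production shape (32760 padded to 32768, 0.024%) passes, while the 129-to-256 worst case (49.6%) is rejected.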
It has been verified on the R9700 that Wan2.1 DiT speed increased from 12.6 s/step to 9.17 s/step (1.38x), with virtually identical video output.
@coderfeli Per your guidance, applied the new cleanup skill. Result: 0 of 5 groups landed; the ASM hash gate is satisfied trivially.

Final ISA hash (the gate): matches the P7-A baseline byte-for-byte (because no edits landed). Verified on R9600D; perf parity.

Per-group commit list: none — no group's diff survived the verification step, so nothing was committed to the follow-up branch.

Skill gap writeup: detailed root-cause analysis and proposed v2 fixes filed at sunway513/aiter-forge#48.

Net of all of the above: the skill needs a v2 before it produces a non-empty diff on real gfx1201 kernels. Happy to drive that v2 if you want me to take it; otherwise the gap report has enough detail for whoever owns the skill. — Driver: aiter-forge skill PR sunway513/aiter-forge#48
…R#462 cleanup)

Apply the same internal-types cleanup pattern as upstream FlyDSL PR ROCm#462 (coderfeli) to the gfx1201 flash_attn_func kernel. Replaces 110 of 116 raw MLIR dialect call sites (94.8% reduction) with FlyDSL public Numeric wrappers and Vector helpers.

Changes by category:
- 11 arith.constant + 18 arith.index sites -> fx.Int32/fx.Float32/fx.Index
- 11 arith.index_cast sites -> fx.Index/fx.Int32/fx.Int64 wrappers
- 14 arith.AddFOp/SubFOp/MulFOp/MaxNumFOp + 8 arith.AddIOp/MulIOp etc. -> local _fadd/_fsub/_fmul/_fmax helpers that preserve the fastmath flag via lowercase arith.{addf,subf,mulf} + arith.MaxNumFOp
- 14 vector.load/store/extract/from_elements/bitcast/broadcast sites -> Vec.load / Vec(x).store / Vec(x)[i] / Vec.from_elements(..., dtype)
- 4 _llvm.GEPOp/LoadOp/StoreOp sites -> buffer_ops.get_element_ptr + _pointer_load/_pointer_store helpers
- 4 scf.IfOp + 1 scf.for_ + 5 scf.YieldOp sites -> Python `if cond:` and `for ... in range(0, upper, step, init=...)` natural form
- 4 _fly.extract_aligned_pointer_as_index sites -> _extract_aligned_pointer local helper
- arith.constant_vector x2 -> Vec.filled
- arith.trunc_f / arith.truncf -> Vec(x).to(elem_dtype) + fx.Float32(x).to(...)
- math_dialect.fma -> fmath.fma

Worked around the 4 skill-v1 gaps documented in P7-D's audit:
1. fx wrapper-vs-raw mismatch: every wrapper that flows into raw MLIR ops is unwrapped via PR462's _to_raw helper (imported as _raw)
2. fx.Vec does not exist: use `from flydsl.expr.typing import Vector as Vec`
3. fastmath cannot be dropped: helpers preserve fastmath=fm_fast everywhere
4. scf.IfOp restructure is manual: the 16-yield CAUSAL mask block is unfolded SSA-style (PR462 lines 700-870 pattern)

Two bytecode-preservation tricks beyond PR462's template:
- Explicit arith.cmpi(slt) for q_in_bounds — fx.Index's < operator defaults to unsigned compare, which would emit v_cmp_gt_u64_e64 vs the baseline's v_cmp_gt_i64_e64 and break ASM equality
- aiter's dtype_to_elem_type returns a raw ir.Type while Vec.make_type needs a Numeric class — added a local _NUMERIC_MAP = {f32: fx.Float32, ...}

Final ASM verification on R9600D (gfx1201), wan-best container:
- baseline ASM SHA256 (PR head affebbe, kernel f049714d): 4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
- refactor ASM SHA256: 4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
-> BYTE-EQUAL. All 22 IR pipeline stages produce identical final ISA.

Perf: 49.68 TFLOPS mid3 mean (5 runs, S=32768 H=12 D=128 bf16 noncausal), baseline 49.84, ratio 99.7% (within bf16 noise floor).

Tests: op_tests/flydsl_tests/test_flydsl_fmha.py — 10 passed, 2 skipped (multi-GPU only, same as baseline). Includes both causal=False (Wan2.1 production hot path) and causal=True coverage.

Lint: black + ruff both clean. Diffstat: 447 lines refactored (+255 / -192).

Remaining 6 raw MLIR call sites are intentional and isolated to helper functions that map 1:1 to upstream PR ROCm#462:
- arith.cmpi(slt) — preserve signed compare for ISA hash equality
- arith.MaxNumFOp — inside the _fmax helper to preserve the fastmath flag
- _llvm.LoadOp/StoreOp — inside the _pointer_{load,store} helpers
- _memref.load — scalar element load with no Vec equivalent
- _fly.extract_aligned_pointer_as_index — inside _extract_aligned_pointer
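The signed-vs-unsigned compare pitfall is easy to reproduce in plain Python: reinterpreting a negative 64-bit value as unsigned flips the comparison result, which is exactly why the explicit signed cmpi had to be kept.

```python
MASK64 = 0xFFFFFFFFFFFFFFFF


def as_u64(x: int) -> int:
    """Reinterpret a signed 64-bit int as unsigned (two's-complement bits)."""
    return x & MASK64


a, b = -1, 1
signed_lt = a < b                     # signed view: -1 < 1
unsigned_lt = as_u64(a) < as_u64(b)   # unsigned view: 2**64 - 1 vs 1
```

With signed semantics -1 < 1 holds; with unsigned semantics the same bit pattern compares as 2**64 - 1 > 1, so the two compare flavors lower to different instructions and different ISA bytes.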
@coderfeli Refactor done.

Hard gate cleared: the final ISA SHA256 is byte-equal to the pre-refactor baseline.

Op-call reduction: 116 → 6 raw MLIR dialect call sites (94.8%). The remaining 6 are intentional and isolated to helper functions (e.g. signed cmpi to preserve the baseline ISA).

Perf: 49.68 TFLOPS mid3 mean over 5 runs (S=32768, H=12, D=128, bf16, non-causal) on R9600D gfx1201; baseline 49.84, ratio 99.7% — within bf16 noise.

Tests: 10 passed, 2 skipped (multi-GPU only). Lint: black + ruff clean.

Two byte-preservation tweaks beyond your PR #462 template that I'd recommend documenting in the cleanup skill: the explicit signed compare for q_in_bounds, and the _NUMERIC_MAP bridge from raw ir.Type to Numeric classes.

Also documenting the 4 skill-v1 gaps that P7-D's audit caught earlier today (wrapper-vs-raw mismatch, missing fx.Vec, fastmath preservation, manual scf.IfOp restructure).

Diffstat: +255 / −192. Ready for your review.
…2990)

PR #2969 already builds a separate kernel variant per (num_heads, head_dim) tuple via the lru_cache key in `_get_kernel`. Wan2.2 shapes (H=24, D=128) JIT-compile and run on the existing kernel path with no kernel changes.

Adds three Wan2.2 self-attention shapes to the bf16 correctness parametrize as a regression guard, and updates the wrapper docstring to document the validated production shapes for both Wan2.1 and Wan2.2.

Validated on R9600D (gfx1201), HIP_VISIBLE_DEVICES=8, ROCm 7.2:

Shape (B,S,H,D) bf16        | cos_min  | FlyDSL  | SDPA   | speedup
----------------------------+----------+---------+--------+---------
(1, 8190, 24, 128)  480p    | 0.999985 | 17.1ms  | 29.3ms | 1.72x
(1, 18480, 24, 128) 720p    | 0.999986 | 99.3ms  | 144ms  | 1.46x
(1, 42840, 24, 128) 1080p   | 0.999986 | 557ms   | 758ms  | 1.36x

E2E Wan2.2 480p, 81 frames, 30 steps, single R9600D:
- SDPA baseline (P7-N): 704.3 s, 7.20 s/step DiT, 15.29 GB peak
- FlyDSL ON: <pending full run; smoke 5-step shows 6.6 s/step>
- 300 self-attn calls per 5 steps all hit FlyDSL (verified via dispatcher hit counter); cross-attention falls back to SDPA as expected (Lq != Lk rejected by the wrapper).

The 0.5% non-causal padding safety guard introduced in PR #2969 also clears for all three Wan2.2 shapes (worst case 720p: 80/18560 = 0.43%).

No callers in this repo invoke flydsl_flash_attn_func directly; the expected dispatcher pattern (cast q/k from RoPE fp32 to v.dtype, then hand off to flydsl_flash_attn_func when self-attn) is documented in the wrapper docstring for downstream Wan2.2 integrators.

Depends on #2969.

Signed-off-by: Peng Sun <peng.sun@amd.com>
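The claim that all three Wan2.2 shapes clear the 0.5% guard can be checked with the same rounding arithmetic (pure Python; the threshold and BLOCK_M come from the PR description):

```python
BLOCK_M = 128


def pad_ratio(seq_len: int) -> float:
    """Fraction of padded tokens after rounding up to a BLOCK_M multiple."""
    padded = -(-seq_len // BLOCK_M) * BLOCK_M  # ceiling division
    return (padded - seq_len) / padded


# Wan2.2 self-attention sequence lengths: 480p, 720p, 1080p
ratios = {s: pad_ratio(s) for s in (8190, 18480, 42840)}
# 720p (18480 -> 18560) is the worst case at 80/18560
```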
@sunway513 I changed the stream param from None (which means the default stream) to torch.cuda.current_stream(). We can merge after CI passes.
Co-authored-by: Cursor <cursoragent@cursor.com>
@vivienfanghuagood @gyohuangxin Can we add gfx1201 CI to aiter? It doesn't seem to be covered now.
…2985)

PR #2959 introduced .github/scripts/install_triton.sh and added an "Install amd-triton" step to aiter-test.yaml that calls the script inside the docker container. The container's working directory is the PR's checkout, so any PR opened or last synced before #2959 landed on main does not contain the script and fails with:

bash: line 1: ./.github/scripts/install_triton.sh: No such file
##[error]Process completed with exit code 127.

This blocks Standard Tests on every stale PR (e.g. #2969, all 9/10 shards failing), forcing authors to rebase just to get green CI.

Fix: in the Install amd-triton step, fall back to fetching the script from the base ref via raw.githubusercontent.com when it is not present in the runner workspace. Workflow files for PR events always come from the base branch, so this stays consistent with the rest of the CI flow and adds no security boundary crossing. Applied symmetrically to the Standard Tests (1 GPU) and Multi-GPU Tests (8 GPU) jobs.

atom-test.yaml and sglang_downstream.yaml also call the script after a fresh git clone of the PR sha and would benefit from a similar fallback in a follow-up.
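A hedged sketch of the fallback step; the variable names and the pinned base-branch URL are illustrative, not the actual workflow YAML:

```shell
SCRIPT=./.github/scripts/install_triton.sh
# Base-branch copy of the script, fetched when the PR checkout predates it.
RAW_URL="https://raw.githubusercontent.com/ROCm/aiter/main/.github/scripts/install_triton.sh"

run_install_triton() {
  if [ -f "$SCRIPT" ]; then
    bash "$SCRIPT"                  # fresh PR: script is in the checkout
  else
    curl -fsSL "$RAW_URL" | bash    # stale PR: fall back to the base ref
  fi
}
```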
ROCm/FlyDSL#472 Hi, it's submitted!
Summary

Adds `aiter.ops.flydsl.flydsl_flash_attn_func`, a self-attention kernel for gfx1201 / RDNA4 (Radeon AI PRO R9700 / R9600D / RX 9070 XT). Closes the gap that AITER had no FlyDSL FMHA backend for the RDNA4 line.

The kernel itself is upstream FlyDSL's gfx1201 best variant: WMMA 16x16x16 on wave32, BLOCK_N=32, the `rocdl.exp2` intrinsic, software-pipelined GEMM2, and overlapped V global load.

Wrapper design
- Built kernels are cached by `(num_heads, head_dim, causal, dtype)` via `lru_cache`, so a hot path only pays compile cost once.
- `seq_len` is rounded up to the next multiple of `BLOCK_M=128` with zero-padded Q/K/V; padded Q rows are sliced off before returning. Callers never see the alignment requirement. This unblocks production seq_len values that aren't natural multiples of 128 (e.g. Wan2.1 1.3B's `S=32760`).
- `Lq != Lk` raises with a clear message, so callers can fall back to SDPA / a dedicated cross-attn path.
- Validates `head_dim >= 64` and `head_dim % 32 == 0` (kernel WMMA tile alignment).

Production validation

End-to-end wall time on Wan2.1-T2V-1.3B (R9600D 16-card box, 480x832, 81 frames, 30 sampling steps; self-attn shape `S=32760` padded to `32768`, `B=1, H=12, D=128`, bf16). The exact same wrapper was used to drive 1800 self-attn calls per generation; cross-attn (Lq=32760, Lk=512, ~3% of e2e by measurement) intentionally remains on SDPA's flash backend.

Test plan

`op_tests/flydsl_tests/test_flydsl_fmha.py` covers:
- `B=1, S=32768, H=12, D=128` (Wan2.1 padded)
- `B=2, S=1024, H=8, D=128` (sanity)
- `B=1, S=32760, H=12, D=128` (exercises the auto-padding path)
- cross-attention rejection with a `"self-attention"` message
- `head_dim` out-of-range (e.g. 48) is rejected

Tests gate on a `gfx1201` device (skips elsewhere) and on `flydsl` package availability. All cases pass on a real R9600D (gfx1201) box: cosine similarity vs SDPA min `0.999985`, mean `0.999993` for bf16. `black --check` and `ruff check` are clean.

Files
- `aiter/ops/flydsl/fmha_kernels.py` — public wrapper (149 lines).
- `aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py` — FlyDSL kernel module (copied from upstream).
- `aiter/ops/flydsl/kernels/kernels_common.py` — `dtype_to_elem_type` helper added (additive only, used by other FlyDSL kernel families too).
- `aiter/ops/flydsl/__init__.py` — register `flydsl_flash_attn_func`.
- `op_tests/flydsl_tests/test_flydsl_fmha.py` — new test module.