
[FLYDSL] Add gfx1201 (RDNA4) flash_attn_func backend#2969

Merged
coderfeli merged 8 commits into ROCm:main from sunway513:feat/flydsl-fmha-gfx1201 on May 3, 2026
Conversation

@sunway513
Collaborator

Summary

Adds aiter.ops.flydsl.flydsl_flash_attn_func, a self-attention kernel for gfx1201 / RDNA4 (Radeon AI PRO R9700 / R9600D / RX 9070 XT). Closes the gap left by AITER having no FlyDSL FMHA backend for the RDNA4 line.

The kernel itself is upstream FlyDSL's gfx1201 best variant: WMMA 16x16x16 on wave32, BLOCK_N=32, rocdl.exp2 intrinsic, software-pipelined GEMM2, and overlapped V global load.

Wrapper design

  • Build cache: kernels keyed by (num_heads, head_dim, causal, dtype) via lru_cache, so a hot path only pays compile cost once.
  • Auto-padding: seq_len is rounded up to the next multiple of BLOCK_M=128 with zero-padded Q/K/V; padded Q rows are sliced off before returning. Callers never see the alignment requirement. This unblocks production seq_len values that aren't natural multiples of 128 (e.g. Wan2.1 1.3B's S=32760).
  • Self-attention only: Lq != Lk raises with a clear message, so callers can fall back to SDPA / a dedicated cross-attn path.
  • Constraints: head_dim >= 64 and head_dim % 32 == 0 (kernel WMMA tile alignment).
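The cache-and-pad flow described above can be sketched roughly as follows. This is an illustrative stand-in, not the real wrapper in aiter/ops/flydsl/fmha_kernels.py: plain SDPA substitutes for the FlyDSL kernel build, and the helper names are hypothetical.

```python
from functools import lru_cache

import torch
import torch.nn.functional as F

BLOCK_M = 128  # kernel alignment requirement on seq_len


@lru_cache(maxsize=None)
def _get_kernel(num_heads, head_dim, causal, dtype):
    # Stand-in for the FlyDSL kernel build: each (H, D, causal, dtype)
    # combination pays the compile cost only on its first call.
    def run(q, k, v):
        # BSHD -> BHSD for SDPA, then back
        o = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
        )
        return o.transpose(1, 2)

    return run


def flydsl_flash_attn_func(q, k, v, causal=False):
    B, S, H, D = q.shape
    if k.shape[1] != S:
        raise ValueError("self-attention only: Lq must equal Lk")
    if D < 64 or D % 32 != 0:
        raise ValueError("head_dim must be >= 64 and a multiple of 32")
    # Auto-padding: round seq_len up to the next multiple of BLOCK_M with
    # zero-padded Q/K/V, then slice the padded query rows off again.
    S_pad = -(-S // BLOCK_M) * BLOCK_M
    if S_pad != S:
        pad = (0, 0, 0, 0, 0, S_pad - S)  # pads dim 1 (seq_len) of BSHD
        q, k, v = (F.pad(t, pad) for t in (q, k, v))
    out = _get_kernel(H, D, causal, q.dtype)(q, k, v)
    return out[:, :S]
```

Callers thus see neither the 128-alignment requirement nor the per-shape compile cost after the first call.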

Production validation

End-to-end wall time on Wan2.1-T2V-1.3B (R9600D 16-card box, 480x832, 81 frames, 30 sampling steps; self-attn shape S=32760 padded to 32768, B=1, H=12, D=128, bf16):

| Backend | e2e wall time | kernel @ S=32768 |
| --- | --- | --- |
| PyTorch SDPA fallback | 617.1 s / 10.29 min | 29.0 TFLOPS |
| Triton FA (BM=256, BN=32, gfx1201-tuned) | 471.0 s / 7.85 min | ~38 TFLOPS |
| flydsl_flash_attn_func (this PR) | 460.9 s / 7.68 min | 49.4 TFLOPS |
  • vs SDPA: 1.34x e2e (25.3% faster); 1.70x at kernel level.
  • vs Triton FA: 2.1% e2e improvement, with micro-benchmark headroom remaining (a 22-variant production-shape sweep showed the top 4 variants within 7 ms of each other).
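For reference, the kernel-level TFLOPS figures above follow from the standard non-causal attention FLOP count of 4·B·H·S²·D (2·B·H·S²·D each for QKᵀ and PV, ignoring the fused softmax):

```python
def attn_tflops(batch, heads, seq, dim, ms):
    """Achieved TFLOPS for one non-causal attention call taking `ms` milliseconds."""
    flops = 4 * batch * heads * seq * seq * dim  # 2*B*H*S^2*D each for QK^T and PV
    return flops / (ms * 1e-3) / 1e12


# Wan2.1 padded self-attn shape: B=1, H=12, S=32768, D=128
flops = 4 * 1 * 12 * 32768**2 * 128  # ~6.60e12 FLOPs per call
# At 49.4 TFLOPS that is ~134 ms per call; at SDPA's 29.0 TFLOPS, ~228 ms.
```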

The exact same wrapper was used to drive 1800 self-attn calls per generation; cross-attn (Lq=32760, Lk=512, ~3% of e2e by measurement) intentionally remains on SDPA's flash backend.

Test plan

op_tests/flydsl_tests/test_flydsl_fmha.py covers:

  • Correctness vs SDPA on:
    • Aligned production shape B=1, S=32768, H=12, D=128 (Wan2.1 padded)
    • Smaller aligned B=2, S=1024, H=8, D=128 (sanity)
    • Unaligned B=1, S=32760, H=12, D=128 (exercises the auto-padding path)
  • Argument validation:
    • Cross-attn raises with "self-attention" message
    • head_dim out-of-range (e.g. 48) is rejected
    • Mismatched dtype across q/k/v is rejected
  • Test is gated on gfx1201 device (skips elsewhere) and on flydsl package availability.

All cases pass on a real R9600D (gfx1201) box: cosine similarity vs SDPA min 0.999985, mean 0.999993 for bf16. black --check and ruff check are clean.
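The device/package gating described above is the usual pytest skip pattern; a minimal sketch (the exact helper names and marker layout in the real test module may differ; gcnArchName is the ROCm-build device property):

```python
import importlib.util

import pytest


def _is_gfx1201():
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(0)
    # gcnArchName is exposed on ROCm builds of PyTorch, e.g. "gfx1201"
    return "gfx1201" in getattr(props, "gcnArchName", "")


# Skip the whole module unless both gates pass.
pytestmark = [
    pytest.mark.skipif(not _is_gfx1201(), reason="requires a gfx1201 (RDNA4) device"),
    pytest.mark.skipif(
        importlib.util.find_spec("flydsl") is None, reason="flydsl not installed"
    ),
]
```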

Files

  • aiter/ops/flydsl/fmha_kernels.py — public wrapper (149 lines).
  • aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py — FlyDSL kernel module (copied from upstream).
  • aiter/ops/flydsl/kernels/kernels_common.py — dtype_to_elem_type helper added (additive only; also used by other FlyDSL kernel families).
  • aiter/ops/flydsl/__init__.py — register flydsl_flash_attn_func.
  • op_tests/flydsl_tests/test_flydsl_fmha.py — new test module.

Adds aiter.ops.flydsl.flydsl_flash_attn_func, a self-attention kernel
optimized for gfx1201 (Radeon AI PRO R9700 / R9600D / RX 9070 XT). The
kernel uses WMMA 16x16x16 on wave32, BLOCK_N=32, rocdl.exp2 intrinsic,
software-pipelined GEMM2, and overlapped V global load.

The Python wrapper:
  - Caches built kernels by (num_heads, head_dim, causal, dtype).
  - Auto-pads seq_len up to a multiple of BLOCK_M=128 with zero-padded
    K/V/Q; the padded query rows are sliced off before returning, so
    callers never see the alignment requirement.
  - Restricts to self-attention (Lq == Lk); cross-attention raises.
  - Validates head_dim >= 64 and head_dim % 32 == 0.

Production validation on Wan2.1-T2V-1.3B (480x832, 81 frames, 30 steps;
self-attn shape S=32760 padded to 32768, B=1, H=12, D=128, bf16):

  Backend           e2e wall time      kernel TFLOPS @ S=32768
  ----------------  -----------------  -----------------------
  SDPA fallback     617.1 s / 10.29m   29.0
  Triton FA (BM=256/BN=32, gfx1201)
                    471.0 s /  7.85m   ~38
  flydsl_flash_attn 460.9 s /  7.68m   49.4

  vs SDPA: 1.34x e2e speedup (25.3% faster); 1.70x at kernel level.
  vs Triton FA: 2.1% e2e improvement.

Tests in op_tests/flydsl_tests/test_flydsl_fmha.py cover correctness vs
SDPA on aligned (B=1 S=32768 H=12 D=128 and B=2 S=1024 H=8 D=128) and
unaligned (B=1 S=32760 H=12 D=128, exercises the padding path) shapes,
plus argument validation (cross-attn rejection, head_dim, dtype).

The kernel itself is upstream FlyDSL's gfx1201 best variant (BLOCK_N=32
+ rocdl.exp2 + pipelined GEMM2 + overlapped V load); the Python file is
copied into aiter.ops.flydsl.kernels/ following the existing convention
for FlyDSL-backed kernels in this directory. dtype_to_elem_type is added
to kernels_common for symmetry with the other FlyDSL kernel families.
@sunway513 sunway513 requested review from a team and Copilot April 29, 2026 18:31
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2969 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

Adds a new FlyDSL-backed FlashAttention (FMHA) implementation targeting gfx1201 / RDNA4, exposing it as a public API and adding test coverage for correctness and input validation.

Changes:

  • Introduces aiter.ops.flydsl.flydsl_flash_attn_func wrapper with kernel build caching and seq_len auto-padding.
  • Adds the gfx1201 FlyDSL kernel module and a shared dtype_to_elem_type helper.
  • Registers the new API in aiter.ops.flydsl and adds a gfx1201-gated pytest module.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| op_tests/flydsl_tests/test_flydsl_fmha.py | New tests for FMHA correctness vs SDPA and argument validation (gfx1201 + FlyDSL gated). |
| aiter/ops/flydsl/kernels/kernels_common.py | Adds dtype_to_elem_type() utility for FlyDSL MLIR scalar type selection. |
| aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py | Adds upstream FlyDSL gfx1201 FlashAttention kernel implementation. |
| aiter/ops/flydsl/fmha_kernels.py | New public wrapper providing caching, padding, validation, and BSHD interface. |
| aiter/ops/flydsl/__init__.py | Exposes flydsl_flash_attn_func when FlyDSL is available and version-compatible. |


@coderfeli
Collaborator

coderfeli commented Apr 30, 2026

@sunway513 Large quantities of raw MLIR types are used in this kernel, including arith, scf, memref, and vector. They are not forbidden, but they make the code very difficult to read. I suggest refactoring a bit. ROCm/FlyDSL#462 — you can follow this PR and this skill https://github.com/ROCm/FlyDSL/blob/main/.claude/skills/flydsl-internal-types-cleanup/SKILL.md to do it. If done as expected, it will not affect perf.

Addresses 5 of 6 Copilot inline review comments on PR ROCm#2969 (the 6th —
padded softmax bug in the kernel — is handled separately):

  #1 (multi-GPU device context): wrap _get_kernel() build and exe()
     launch in 'with torch.cuda.device(q.device.index):' so callers
     whose current device differs from q.device get the kernel built
     and launched on the correct device/stream.

  #3 (docstring mismatch): module docstring's cache-key list now
     includes waves_per_eu and daz to match _get_kernel signature.

  #4 (unused import): drop the unused 'flydsl =' assignment from
     pytest.importorskip in test file (silences ruff F841).

  #5 (dtype/causal coverage): add two correctness tests —
       - test_flydsl_fmha_correctness_f16 (Wan2.1-shape, fp16, non-causal)
       - test_flydsl_fmha_correctness_causal_small (small bf16, causal)

  #6 (device validation): at the top of flydsl_flash_attn_func, raise
     ValueError when q/k/v are not on the same device, and when the
     active device's gcnArchName is not gfx1201.

Bonus tests:
  - test_flydsl_fmha_correctness_multi_device: runs the kernel on
    cuda:1 while current device is cuda:0 in a subprocess, exercising
    the device-context wrap (#1). xfails cleanly if the FlyDSL runtime
    pins to device 0 internally.
  - test_flydsl_fmha_rejects_device_mismatch: validates the same-device
    guard from #6 directly (q on cuda:0, k/v on cuda:1).

Verified with black, ruff, and pytest (10/10 PASS) on an R9600D (gfx1201) box.

Wrapper-level safety guard for the padded-softmax bug raised by Copilot
inline comment #2 on PR ROCm#2969. Padded K/V tokens produce QK^T = 0 but
exp(0) = 1 still contributes to the softmax denominator and silently
scales the output for non-causal attention. Causal mode masks padded
positions so it is unaffected.

Empirical RCA at aiter-forge-baselines/2969_padded_softmax_rca.md:
  - Wan2.1 production (S_real=32760, S_pad=32768, ratio=0.024%):
    cos_min 0.999992, max_abs 0.0008 — safe, indistinguishable from
    bf16 noise floor.
  - 50% padding worst case: rel_err 37.3%, max_abs 0.281 — silent
    output scaling, would corrupt downstream.

Implements option (d) from the RCA decision doc (signed off by Peng):
hybrid threshold. Non-causal calls with n_pad/seq_len_pad > 0.005 are
rejected with a ValueError that points the caller at the three valid
remediations (causal=True, pre-pad to multiple of 128, or use a
masking-aware kernel).

Threshold rationale: 0.5% is the bf16 mantissa precision floor (~0.4%,
7 mantissa bits) plus 1 bit of margin. Production Wan2.1 (0.024%)
clears it by 20x, so the hot path stays open while the silent-disaster
worst case is closed.
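The failure mode reproduces in a few lines of plain PyTorch (an illustration of the mechanism, not the FlyDSL kernel): zero-padded keys score QKᵀ = 0, yet exp(0 − m) still enters the softmax denominator and uniformly shrinks the non-causal output.

```python
import torch


def naive_attn(q, k, v):
    # Unmasked softmax(QK^T / sqrt(d)) @ V, like a kernel with no padding mask.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
S_real, S_pad, D = 8, 16, 4  # 50% padding: the worst case from the RCA
q, k, v = (torch.randn(S_real, D) for _ in range(3))

ref = naive_attn(q, k, v)

# Zero-pad K and V. Padded V rows are zero, so the only effect is that each
# output row is scaled down by Z_real / (Z_real + n_pad * exp(-m)) < 1.
kp = torch.cat([k, torch.zeros(S_pad - S_real, D)])
vp = torch.cat([v, torch.zeros(S_pad - S_real, D)])
padded = naive_attn(q, kp, vp)

rel_err = ((padded - ref).abs().max() / ref.abs().max()).item()
# Large at 50% padding; it shrinks toward the noise floor as the padding
# ratio drops (0.024% in the Wan2.1 production case).
```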

Tests added (op_tests/flydsl_tests/test_flydsl_fmha.py):
  - test_flydsl_fmha_rejects_excessive_padding: B=1, S_real=129
    (S_pad=256, 49.6% pad), causal=False — must raise ValueError with
    "0.5% safety threshold" substring.
  - test_flydsl_fmha_allows_tight_padding: Wan2.1 case S_real=32760,
    causal=False — must succeed and match SDPA reference (cos_min
    >= 0.9999). Regression guard for the production hot path.

Validation on R9600D (gfx1201) inside wan-best container,
HIP_VISIBLE_DEVICES=4: 10 passed, 2 skipped (multi-GPU only).
black --check + ruff check both clean on touched files.

Kernel file aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py is
intentionally untouched — refactor is in a parallel branch.
@vivienfanghuagood
Contributor

It has been verified on the R9700 that Wan 2.1 DiT speed increased from 12.6 s/step to 9.17 s/step (1.37x), with virtually identical video output.

@sunway513
Collaborator Author

@coderfeli Per your guidance, applied the new flydsl-internal-types-cleanup skill (sunway513/aiter-forge#48, source SHA 11dc8b34) to aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py at PR head affebbe6. This is the first real-world dogfood run of the skill, so reporting the result honestly rather than force-fitting a pretty diff.

Result: 0 of 5 groups landed. ASM hash gate satisfied trivially.

| Group | Status | Why |
| --- | --- | --- |
| G1 constants | REVERTED | fx.Int32(...) / fx.Index(...) returns a Numeric wrapper, not a raw ir.Value. Downstream arith.AndIOp / scf.for_ / arith.AddIOp etc. reject the wrapper at construction time (ValueError: Operand N must be a Value). |
| G2 index_casts | REVERTED | Same root cause. fx.Index(seq_len) consumed by scf.for_(_, kv_upper, _) (because kv_upper = seq_len_v) blows up the same way. |
| G3 vectors | SKIPPED | fx.Vec does not exist in FlyDSL 0.1.5 (hasattr(fx, "Vec") == False). The G3 regex also targets literal T.vec(n, T.i32), but PR #2969 uses aliases (v8i16_type, v8f16_type, _v4i32) — zero matches even if fx.Vec existed. |
| G4 trunc/arith/select | SKIPPED | The arith.addf / arith.mulf regex hits 0 (the kernel uses uppercase Op-form arith.AddFOp(..., fastmath=fm_fast) everywhere, and per the skill's exceptions, fastmath sites must stay low-level). The single arith.trunc_f site uses a variable element type. The 3 arith.select sites have cond produced by raw arith.cmpi (not ArithValue), so the rewrite is invalid by the skill's own caveat. Net yield: 0. |
| G5 control flow | SKIPPED | 4 scf.IfOp sites with has_else=True and yielded values; per the skill, manual transformation only. Punted. |

Final ISA hash (the gate)

4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997

Matches P7-A baseline byte-for-byte (because no edits landed). Verified on R9600D, wan-best container, FlyDSL 0.1.5, HIP_VISIBLE_DEVICES=0, with FLYDSL_DUMP_IR=1 FLYDSL_RUNTIME_ENABLE_CACHE=0.

Perf parity

  • Kernel TFLOPS (B=1, S=32768, H=12, D=128, bf16, non-causal): MID3 49.78 (baseline 49.57) — within noise.
  • E2E Wan2.1-T2V-1.3B (480x832, 81 frames, 30 steps): 462.2 s total, 17.91 GB peak VRAM (baseline 459.9 s / 17.91 GB; limits ≤466 s / ≤18.5 GB). PASS.
  • Layer 2 pytest: 10/10 passed, 2 multi-device tests skipped (single-GPU isolation via HIP_VISIBLE_DEVICES=0). Same as P7-A baseline.
  • Layer 5 lint: black --check aiter/ops/flydsl/ clean, ruff check aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py clean. (Other files in the directory have pre-existing ruff warnings unrelated to this PR.)

Per-group commit list

None — no group's diff survived the verification step, so nothing was committed to feat/flydsl-fmha-gfx1201. P7-F's wrapper work (aiter/ops/flydsl/fmha_kernels.py) is untouched.

Skill gap writeup

Detailed root-cause analysis and proposed v2 fixes filed at sunway513/aiter-forge#48 follow-up: branch p7d/skill-gaps-2969, file learnings/tuning/p7d_g1_g4_skill_gap.md. Headline findings:

  1. G1/G2 must emit .ir_value in their replace_template (or skip when the value flows into raw MLIR ops). Today's regex is byte-equivalent in theory but kills the kernel in practice on any file mixing flydsl.expr with flydsl._mlir.dialects.
  2. G3 needs a min-FlyDSL-version declaration (fx.Vec not yet shipped) and an alias-aware regex for vector element types.
  3. G4 regex needs uppercase Op-form coverage (arith.AddFOp(, arith.MulFOp(, etc.) and a fastmath= skip-rule. Without these, the skill yields 0 on this gfx1201 kernel.
  4. arith.select rewrite needs a runtime cond-type guard, not just a template-time caveat.

Net of all of the above: the skill needs a v2 before it produces a non-empty diff on real gfx1201 kernels. Happy to drive that v2 if you want me to take it; otherwise the gap report has enough detail for whoever owns the skill.

Driver: aiter-forge skill PR sunway513/aiter-forge#48
Gap report: sunway513/aiter-forge p7d/skill-gaps-2969 branch, learnings/tuning/p7d_g1_g4_skill_gap.md

…R#462 cleanup)

Apply the same internal-types cleanup pattern as upstream FlyDSL PR ROCm#462
(coderfeli) to the gfx1201 flash_attn_func kernel. Replaces 110 of 116
raw MLIR dialect call sites (94.8% reduction) with FlyDSL public Numeric
wrappers and Vector helpers.

Changes by category:
- 11 arith.constant + 18 arith.index sites -> fx.Int32/fx.Float32/fx.Index
- 11 arith.index_cast sites -> fx.Index/fx.Int32/fx.Int64 wrappers
- 14 arith.AddFOp/SubFOp/MulFOp/MaxNumFOp + 8 arith.AddIOp/MulIOp etc
  -> local _fadd/_fsub/_fmul/_fmax helpers that preserve fastmath flag
  via lowercase arith.{addf,subf,mulf} + arith.MaxNumFOp
- 14 vector.load/store/extract/from_elements/bitcast/broadcast sites
  -> Vec.load / Vec(x).store / Vec(x)[i] / Vec.from_elements(...,dtype)
- 4 _llvm.GEPOp/LoadOp/StoreOp sites -> buffer_ops.get_element_ptr +
  _pointer_load/_pointer_store helpers
- 4 scf.IfOp + 1 scf.for_ + 5 scf.YieldOp sites -> Python `if cond:` and
  `for ... in range(0, upper, step, init=...)` natural form
- 4 _fly.extract_aligned_pointer_as_index sites -> _extract_aligned_pointer
  local helper
- arith.constant_vector x2 -> Vec.filled
- arith.trunc_f / arith.truncf -> Vec(x).to(elem_dtype) + fx.Float32(x).to(...)
- math_dialect.fma -> fmath.fma

Worked around the 4 skill-v1 gaps documented in P7-D's audit:
1. fx wrapper-vs-raw mismatch: every wrapper that flows into raw MLIR ops
   is unwrapped via PR462's _to_raw helper (imported as _raw)
2. fx.Vec does not exist: use `from flydsl.expr.typing import Vector as Vec`
3. fastmath cannot be dropped: helpers preserve fastmath=fm_fast everywhere
4. scf.IfOp restructure is manual: the 16-yield CAUSAL mask block is
   unfolded SSA-style (PR462 lines 700-870 pattern)

Two bytecode-preservation tricks beyond PR462's template:
- Explicit arith.cmpi(slt) for q_in_bounds — fx.Index < operator defaults
  to unsigned compare, which would emit v_cmp_gt_u64_e64 vs baseline
  v_cmp_gt_i64_e64 and break ASM equality
- aiter's dtype_to_elem_type returns raw ir.Type while Vec.make_type needs
  a Numeric class — added local _NUMERIC_MAP = {f32: fx.Float32, ...}

Final ASM verification on R9600D (gfx1201) wan-best container:
  baseline ASM SHA256 (PR head affebbe, kernel f049714d):
    4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
  refactor ASM SHA256:
    4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
  -> BYTE-EQUAL. All 22 IR pipeline stages produce identical final ISA.

Perf: 49.68 TFLOPS mid3 mean (5 runs, S=32768 H=12 D=128 bf16 noncausal),
baseline 49.84, ratio 99.7% (within bf16 noise floor).

Tests: op_tests/flydsl_tests/test_flydsl_fmha.py — 10 passed, 2 skipped
(multi-GPU only, same as baseline). Includes both causal=False (Wan2.1
production hot path) and causal=True coverage.

Lint: black + ruff both clean.

Diffstat: 447 lines refactored (+255 / -192). Remaining 6 raw MLIR call
sites are intentional and isolated to helper functions that map 1:1 to
PR ROCm#462 upstream:
  arith.cmpi(slt)  — preserve signed compare for ISA hash equality
  arith.MaxNumFOp  — inside _fmax helper to preserve fastmath flag
  _llvm.LoadOp/StoreOp — inside _pointer_{load,store} helpers
  _memref.load     — scalar element load with no Vec equivalent
  _fly.extract_aligned_pointer_as_index — inside _extract_aligned_pointer
@sunway513
Collaborator Author

@coderfeli Refactor done — kernel flash_attn_func_gfx1201.py now uses the same FlyDSL internal-types pattern as your upstream PR #462. Pushed as b02162e0.

Hard gate cleared: final ISA SHA256 is byte-equal to the pre-refactor baseline.

baseline: 4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
refactor: 4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997

Op-call reduction: 116 → 6 raw MLIR dialect call sites (94.8%). Remaining 6 are intentional and isolated to helper functions (signed-cmpi to preserve baseline ISA, _fmax for fastmath flag, _pointer_{load,store} and _extract_aligned_pointer helpers, _memref.load for scalar element loads).

Perf: 49.68 TFLOPS mid3 mean over 5 runs (S=32768, H=12, D=128, bf16, non-causal) on R9600D gfx1201; baseline 49.84, ratio 99.7% — within bf16 noise.

Tests: op_tests/flydsl_tests/test_flydsl_fmha.py 10 passed, 2 skipped (multi-GPU only). Both causal=False (Wan2.1 production hot path) and causal=True (small-shape mask coverage) pass.

Lint: black --check and ruff check both clean.

Two byte-preservation tweaks beyond your PR #462 template that I'd recommend documenting in the cleanup skill:

  1. fx.Index < seq_len_v lowers to unsigned v_cmp_*_u64. Baseline uses arith.cmpi(slt, …) which lowers to signed v_cmp_*_i64. For non-negative offsets the result is identical, but the ISA differs. I kept an explicit arith.cmpi(arith.CmpIPredicate.slt, _raw(q_row), _raw(seq_len_v)) for the one site that gates Q OOB rows.
  2. aiter/ops/flydsl/kernels/kernels_common.py::dtype_to_elem_type returns a raw ir.Type (e.g. BF16Type), not a Numeric class, while Vec.make_type(N, dtype) requires a Numeric. Added a local _NUMERIC_MAP in the kernel as the workaround. A future cleanup of kernels_common.py could change dtype_to_elem_type to return both (Numeric, ir.Type), similar to your upstream PR #462 version.
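Point 1 above is the standard signed-vs-unsigned compare distinction: the predicates agree on all non-negative operands (hence identical results for row offsets), but their bit-level semantics differ, which is why arith.cmpi(slt) and the unsigned default lower to different VALU instructions. A self-contained sketch of the two 64-bit predicates:

```python
BITS = 64
MASK = (1 << BITS) - 1


def to_signed(x):
    # Reinterpret a 64-bit pattern as two's-complement signed.
    x &= MASK
    return x - (1 << BITS) if x >> (BITS - 1) else x


def cmp_slt(a, b):  # arith.cmpi slt -> v_cmp_*_i64
    return to_signed(a) < to_signed(b)


def cmp_ult(a, b):  # unsigned compare -> v_cmp_*_u64
    return (a & MASK) < (b & MASK)


# Agree for non-negative operands (Q row offsets are never negative) ...
assert cmp_slt(5, 32768) == cmp_ult(5, 32768) == True
# ... but diverge once the sign bit is set: -1 reinterprets as 2**64 - 1.
assert cmp_slt(-1, 0) is True
assert cmp_ult(-1, 0) is False
```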

Also documenting the 4 skill-v1 gaps that P7-D's audit caught earlier today (wrapper-vs-raw mismatch, missing fx.Vec, fastmath preservation, manual scf.IfOp SSA restructure for the 16-yield CAUSAL mask block) — all four worked around in this commit. Will feed back to the flydsl-internal-types-cleanup skill so v2 can handle them automatically.

Diffstat: +255 / −192. Ready for your review.

coderfeli
coderfeli previously approved these changes May 3, 2026
coderfeli pushed a commit that referenced this pull request May 3, 2026
…2990)

PR #2969 already builds a separate kernel variant per (num_heads, head_dim)
tuple via the lru_cache key in `_get_kernel`. Wan2.2 shapes (H=24, D=128)
JIT-compile and run on the existing kernel path with no kernel changes.

Adds three Wan2.2 self-attention shapes to the bf16 correctness
parametrize as a regression guard, and updates the wrapper docstring to
document the validated production shapes for both Wan2.1 and Wan2.2.

Validated on R9600D (gfx1201), HIP_VISIBLE_DEVICES=8, ROCm 7.2:

  Shape (B,S,H,D) bf16     | cos_min  | FlyDSL  | SDPA   | speedup
  -------------------------+----------+---------+--------+---------
  (1, 8190, 24, 128)  480p | 0.999985 | 17.1ms  | 29.3ms |  1.72x
  (1, 18480, 24, 128) 720p | 0.999986 | 99.3ms  | 144ms  |  1.46x
  (1, 42840, 24, 128) 1080p| 0.999986 | 557ms   | 758ms  |  1.36x

E2E Wan2.2 480p, 81 frames, 30 steps, single R9600D:
  - SDPA baseline (P7-N): 704.3 s, 7.20 s/step DiT, 15.29 GB peak
  - FlyDSL ON           : <pending full run; smoke 5-step shows 6.6s/step>
  - 300 self-attn calls per 5 steps all hit FlyDSL (verified via
    dispatcher hit counter); cross-attention falls back to SDPA as
    expected (Lq != Lk rejected by wrapper).

The 0.5% non-causal padding safety guard introduced in PR #2969 also
clears for all three Wan2.2 shapes (worst case 720p: 80/18560 = 0.43%).
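The guard check for the three Wan2.2 shapes is simple arithmetic against BLOCK_M=128 and the 0.5% threshold from PR #2969:

```python
BLOCK_M = 128
PAD_THRESHOLD = 0.005  # 0.5% non-causal padding safety guard from PR #2969


def pad_ratio(seq_len):
    seq_pad = -(-seq_len // BLOCK_M) * BLOCK_M  # round up to a multiple of 128
    return (seq_pad - seq_len) / seq_pad


# Wan2.2 self-attn seq_lens: 480p, 720p, 1080p
for s in (8190, 18480, 42840):
    assert pad_ratio(s) <= PAD_THRESHOLD
# Worst case is 720p: 80 / 18560 ~= 0.43%, still under the 0.5% guard.
```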

No callers in this repo invoke flydsl_flash_attn_func directly; the
expected dispatcher pattern (cast q/k from RoPE fp32 to v.dtype, then
hand off to flydsl_flash_attn_func when self-attn) is documented in the
wrapper docstring for downstream Wan2.2 integrators.

Depends on #2969.

Signed-off-by: Peng Sun <peng.sun@amd.com>
@coderfeli
Collaborator

coderfeli commented May 3, 2026

@sunway513 I changed the stream param from None (which means the default stream) to torch.cuda.current_stream(). We can merge after CI passes.

Co-authored-by: Cursor <cursoragent@cursor.com>
@coderfeli
Collaborator

@vivienfanghuagood @gyohuangxin can we add gfx1201 CI into aiter? Seems not checked now.

gyohuangxin pushed a commit that referenced this pull request May 3, 2026
…2985)

PR #2959 introduced .github/scripts/install_triton.sh and added an
"Install amd-triton" step to aiter-test.yaml that calls the script
inside the docker container. The container's working directory is the
PR's checkout, so any PR opened or last synced before #2959 landed on
main does not contain the script and fails with:

  bash: line 1: ./.github/scripts/install_triton.sh: No such file
  ##[error]Process completed with exit code 127.

This blocks Standard Tests on every stale PR (e.g. #2969, all 9/10
shards failing), forcing authors to rebase just to get green CI.

Fix: in the Install amd-triton step, fall back to fetching the script
from the base ref via raw.githubusercontent.com when it is not present
in the runner workspace. Workflow files for PR events always come from
the base branch, so this stays consistent with the rest of the CI flow
and adds no security boundary crossing.

Applied symmetrically to the Standard Tests (1 GPU) and Multi-GPU
Tests (8 GPU) jobs. atom-test.yaml and sglang_downstream.yaml also
call the script after a fresh git clone of the PR sha and would
benefit from a similar fallback in a follow-up.
@coderfeli coderfeli merged commit d2454ad into ROCm:main May 3, 2026
44 of 45 checks passed
chun-wan pushed a commit that referenced this pull request May 4, 2026
@vivienfanghuagood
Contributor

@vivienfanghuagood @gyohuangxin can we add gfx1201 CI into aiter? Seems not checked now.

ROCm/FlyDSL#472 Hi, it's submitted!

Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026