Add attention variants and backend guide #2
Merged
Document all attention variants (MHA, PA, MLA, Unified, Sparse, etc.) with backend support matrices, data type coverage, decision trees for choosing the right variant, and practical API examples.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
sunway513 pushed a commit that referenced this pull request on Mar 22, 2026
sunway513 added a commit that referenced this pull request on Apr 30, 2026
Wrapper-level safety guard for the padded-softmax bug raised by Copilot inline comment #2 on PR ROCm#2969.

Padded K/V tokens produce QK^T = 0, but exp(0) = 1 still contributes to the softmax denominator and silently scales the output for non-causal attention. Causal mode masks padded positions, so it is unaffected.

Empirical RCA at aiter-forge-baselines/2969_padded_softmax_rca.md:
- Wan2.1 production (S_real=32760, S_pad=32768, ratio=0.024%): cos_min 0.999992, max_abs 0.0008. Safe, indistinguishable from the bf16 noise floor.
- 50% padding worst case: rel_err 37.3%, max_abs 0.281. Silent output scaling that would corrupt downstream results.

Implements option (d) from the RCA decision doc (signed off by Peng): a hybrid threshold. Non-causal calls with n_pad/seq_len_pad > 0.005 are rejected with a ValueError that points the caller at the three valid remediations (causal=True, pre-pad to a multiple of 128, or use a masking-aware kernel).

Threshold rationale: 0.5% is the bf16 mantissa precision floor (~0.4%, 7 mantissa bits) plus 1 bit of margin. Production Wan2.1 (0.024%) clears it by 20x, so the hot path stays open while the silent-disaster worst case is closed.

Tests added (op_tests/flydsl_tests/test_flydsl_fmha.py):
- test_flydsl_fmha_rejects_excessive_padding: B=1, S_real=129 (S_pad=256, 49.6% pad), causal=False. Must raise a ValueError containing the "0.5% safety threshold" substring.
- test_flydsl_fmha_allows_tight_padding: the Wan2.1 case, S_real=32760, causal=False. Must succeed and match the SDPA reference (cos_min >= 0.9999). Regression guard for the production hot path.

Validation on R9600D (gfx1201) inside the wan-best container with HIP_VISIBLE_DEVICES=4: 10 passed, 2 skipped (multi-GPU only). black --check and ruff check are both clean on the touched files.

Kernel file aiter/ops/flydsl/kernels/flash_attn_func_gfx1201.py is intentionally untouched; the refactor is in a parallel branch.
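The failure mode is easy to reproduce outside the kernel. The sketch below is plain PyTorch and not taken from the aiter wrappers; shapes and tensor names are illustrative. It zero-pads K/V to the 50% worst case from the RCA and compares against the unpadded reference: every padded column scores QK^T = 0, and exp(0) = 1 inflates each softmax denominator.

```python
import torch

torch.manual_seed(0)
d = 64
s_real, s_pad = 4, 8  # 50% padding: the worst case from the RCA

q = torch.randn(1, s_real, d)
k = torch.randn(1, s_real, d)
v = torch.randn(1, s_real, d)

# Zero-pad K/V out to s_pad, as a pad-to-block-multiple wrapper would.
k_pad = torch.zeros(1, s_pad, d)
v_pad = torch.zeros(1, s_pad, d)
k_pad[:, :s_real] = k
v_pad[:, :s_real] = v

scale = d ** -0.5
ref = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v
out = torch.softmax(q @ k_pad.transpose(-1, -2) * scale, dim=-1) @ v_pad

# Each padded column contributes exp(0) = 1 to the softmax denominator,
# so the non-causal output is silently scaled toward zero.
rel_err = ((out - ref).abs() / ref.abs().clamp_min(1e-6)).max()
print(f"max relative error at 50% padding: {rel_err:.3f}")
```

And a minimal sketch of the hybrid-threshold check itself. The function name `check_padding_safety` and its signature are hypothetical (the actual guard lives in the flydsl wrapper); only the 0.005 threshold, the causal exemption, and the three remediations come from the commit message above.

```python
PAD_RATIO_THRESHOLD = 0.005  # 0.5%: bf16 mantissa floor (~0.4%) plus 1 bit of margin

def check_padding_safety(seq_len_real: int, seq_len_pad: int, causal: bool) -> None:
    """Reject non-causal calls whose padding ratio exceeds the safety threshold."""
    if causal:
        # Causal masking already excludes padded positions from the softmax.
        return
    n_pad = seq_len_pad - seq_len_real
    if n_pad / seq_len_pad > PAD_RATIO_THRESHOLD:
        raise ValueError(
            f"Padding ratio {n_pad / seq_len_pad:.2%} exceeds the 0.5% safety "
            "threshold for non-causal attention. Either pass causal=True, "
            "pre-pad the sequence to a multiple of 128, or use a "
            "masking-aware kernel."
        )
```

On the commit message's own numbers: the Wan2.1 case (n_pad=8, seq_len_pad=32768) gives a ratio of about 0.024% and passes, while the rejection test's S_real=129, S_pad=256 case gives 49.6% and raises.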