Eliminate Legacy MHA Unfused path from ONNX Attention; unify on 3-tier dispatch with causal alignment fix #27992
Conversation
Pull request overview
This PR extends ONNX Runtime’s CUDA Attention operator to support softcap in the unfused CUDA path and adds Memory Efficient Attention (CUTLASS FMHA) decode support (past/present KV cache), along with expanded ONNX/Python/C++ tests and updated backend test filters.
Changes:
- Add `softcap` plumbing (`AttentionParameters.softcap`) and apply softcap to unfused attention logits via a CUDA kernel before softmax (sketched below).
- Implement MEA decode by concatenating past+new KV into the present buffer (via `LaunchConcatNewToPastKV`) and update kernel-selection/verbosity logs.
- Add/adjust Python and C++ tests for MEA decode, unfused softcap, bool/float masks, and asymmetric head-size fallback; update ONNX backend filters accordingly.
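For context on the first bullet, here is a minimal sketch of a tanh-style softcap applied to attention logits before softmax (the common Gemma-style formulation); the kernel name and flat layout are illustrative assumptions, not the PR's actual ApplySoftcap code:

```cuda
#include <cstdint>

// Illustrative tanh-based softcap on a flat logits buffer; the real
// ApplySoftcap kernel's signature and indexing may differ.
__global__ void ApplySoftcapSketch(float* logits, int64_t n, float softcap) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    logits[i] = softcap * tanhf(logits[i] / softcap);
  }
}
```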
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/llm/attention.cc | Adds MEA decode path (concat KV into present), propagates softcap to unfused, and adds verbose kernel-selection logging. |
| onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | Implements ApplySoftcap CUDA kernel and applies it in unfused attention before softmax. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds softcap to AttentionParameters. |
| onnxruntime/test/providers/cpu/llm/attention_op_test.cc | Enables CUDA softcap tests and adds a CUDA MEA decode regression test (forced via env var). |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | Adds v_head_size support to graph IO shapes/bindings; adds MEA alignment helper for decode+mask tests. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py | Adds MEA decode tests (fp16/fp32, bool/float masks), unfused softcap tests, and asymmetric head-size fallback regression test. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | Re-enables/extends GQA MEA decode tests (including bf16) and adjusts padding-mask cases for MEA alignment. |
| onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc | Updates exclusions/comments for attention backend tests based on new softcap/decode behavior. |
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with scratch buffer allocation when the outputs are nullptr, so MEA decode works even when present outputs are not requested (sketched below). Use ORT_RETURN_IF_NOT for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with an early return for zero elements, since q_sequence_length = 0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
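A hedged sketch of the optional-output pattern from point 1; the function and allocator callback are hypothetical names, not the ORT implementation:

```cuda
#include <cstddef>
#include <functional>

// If the graph does not request present_key/present_value, fall back to
// scratch memory instead of failing with ORT_ENFORCE, so the MEA decode
// path still has a buffer to concatenate past+new KV into.
inline float* PresentOrScratchSketch(
    float* present_output, size_t bytes,
    const std::function<float*(size_t)>& alloc_scratch) {
  if (present_output != nullptr) {
    return present_output;      // write directly into the requested output
  }
  return alloc_scratch(bytes);  // temporary buffer owned by the op context
}
```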
titaiwangms
left a comment
Probably not, but check that we are not breaking the graph at any point. For example, #27484.
Also, are the nonpad_kv_seqlens paths totally unrelated to the gap table? I don't see it mentioned at all.
tianleiwu
left a comment
Thanks for the updates here. I re-checked the current head, and the earlier MEA optional-output, softcap zero-size, grid-stride, and mask-test concerns look addressed. I found one remaining correctness issue in the new unfused softcap path: the kQK output is copied after the logits have already been softcapped, and in the softcap+bias branch it is not copied at all.
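A minimal sketch of the ordering this comment asks for, with hypothetical names; the point is that the optional QK output must be snapshotted before any in-place mutation, in every branch, including softcap+bias:

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Copy raw QK logits to the optional output before softcap/bias mutate
// them in place; skipping this in any branch leaks post-softcap values
// (or nothing at all) into the kQK output.
inline cudaError_t EmitQkBeforeMutationSketch(cudaStream_t stream,
                                              const float* logits,
                                              float* qk_output, int64_t n) {
  if (qk_output == nullptr || n == 0) {
    return cudaSuccess;  // output not requested, or empty launch
  }
  return cudaMemcpyAsync(qk_output, logits, n * sizeof(float),
                         cudaMemcpyDeviceToDevice, stream);
}
```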
force-pushed c74bbaa to e5f74e9
force-pushed e5f74e9 to 5afc98e
tianleiwu
left a comment
APPROVE
The implementation is sound, addresses the prior output_qk ordering concern (resolved that thread), and provides comprehensive test coverage across MEA decode, unfused softcap, spec-correct mask→softcap ordering, and the NaN fix. No high-severity issues found.
Highlights
- `output_qk` copy correctly moved before bias/softcap mutations — prior CHANGES_REQUESTED concern resolved.
- MEA decode path has well-managed scratch buffer lifetimes; `kv_is_bsnh` tracking correctly propagates through GQA expansion.
- `mask_filter_value` cap (-1e+30f) is correctly scoped to MEA only and the math is verified.
- Test coverage is thorough: C++ and Python, MEA/unfused, prompt/decode, fp16/bf16/fp32, various mask dims.
Suggestions (non-blocking)
- Causal mask ordering with softcap: In the `has_softcap && has_bias` branch, `ComputeSoftmax` applies `is_unidirectional` causal masking after softcap. The strict ONNX spec folds causal into the mask before softcap. This matches the CPU reference and tests pass, but a TODO comment would help track future spec alignment.
- Defensive assertion: The `has_softcap && has_bias` branch bypasses `use_raw_attention_mask` handling. This is safe today (the ONNX domain converts all masks to `data.attention_bias` before the unfused path), but adding `assert(!use_raw_attention_mask)` guards against future contrib softcap use where `mask_index` could coexist; see the sketch after this list.
- Bool mask test filters: `test_attention_4d_attn_mask_bool_cuda` and `_4d_cuda` are still excluded with "may work now" TODOs. Consider removing them or filing a follow-up issue.
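A sketch of that defensive assertion, with a hypothetical wrapper name; the real check would live inline in the softcap+bias branch:

```cuda
#include <cassert>

// Hypothetical guard from the second suggestion: the ONNX-domain path
// converts every mask into an additive attention_bias before reaching the
// unfused kernel, so a raw mask_index arriving in the softcap+bias branch
// would indicate a dispatch bug rather than a supported configuration.
inline void CheckSoftcapBranchPreconditionsSketch(bool use_raw_attention_mask) {
  assert(!use_raw_attention_mask &&
         "softcap+bias branch expects masks folded into attention_bias");
}
```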
force-pushed a8f10de to 0a85503
force-pushed 0a85503 to a6545be
Multi-agent review summary

I ran this PR through four independent reviewers (gpt-5.3-codex, gpt-5.5, claude-sonnet-4.6, claude-opus-4.7-high). There is strong consensus on one Critical bug plus several Major correctness/structural issues. Findings are consolidated below; nits are omitted unless they hide a bug.

🔴 Critical

C-1. Unified unfused kernel silently produces lower-right when caller asks for upper-left under TensorScatter / external-cache + causal

```cpp
const int past = (seqlens_k != nullptr) ? (kv_end - q_sequence_length) : past_kv_length;
```

When …

Concrete failure (matrix row #4): …

Reachable today via: …

Three of four reviewers flagged this independently. The PR's response (deleting 20 TensorScatter …

Suggested fix: use …
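To make the two alignments concrete, here is a hedged sketch of the visible-key bound (a hypothetical helper, not the elided suggested fix):

```cuda
// Upper-left anchors the causal diagonal at the top-left corner: query
// row q may attend keys [0, q]. Lower-right anchors it at the bottom-right
// corner: row q may attend keys [0, kv_len - q_len + q]. Deriving `past`
// from the mere presence of seqlens_k conflates the two cases.
__host__ __device__ inline int VisibleKeyBoundSketch(int q_row, int q_len,
                                                     int kv_len,
                                                     bool lower_right) {
  const int past = lower_right ? (kv_len - q_len) : 0;
  return past + q_row;  // inclusive index of the last visible key
}
```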
🟠 Major

M-1. TensorScatter …

The ONNX spec does not declare …

Fix: restore the cases after fixing C-1, OR explicitly reject the combination in dispatch with …

M-2.

```cpp
LaunchConcatNewToPastKV(... pk, pk, nk, nk, out_k, out_k, ...);  // K
LaunchConcatNewToPastKV(... pv, pv, nv, nv, out_v, out_v, ...);  // V
```

Each call's K-block and V-block both write to the same destination. Correct only as a "benign data race" (same value, same address). Two issues: …
Fix: add a …
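One possible shape of a fix, sketched with hypothetical names and an assumed contiguous [past | new] layout; each launch touches exactly one destination buffer, so no two lanes ever share an address:

```cuda
#include <cstdint>

// Single-buffer concat, called once for K and once for V with distinct
// source/destination pointers (illustrative; real KV tensors are strided
// per batch and head, which this flat sketch glosses over).
__global__ void ConcatPastNewSketch(const float* past, int64_t past_elems,
                                    const float* added, int64_t new_elems,
                                    float* out) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i >= past_elems + new_elems) return;
  out[i] = (i < past_elems) ? past[i] : added[i - past_elems];
}
// ConcatPastNewSketch<<<grid, block, 0, stream>>>(past_k, pk_n, new_k, nk_n, out_k);
// ConcatPastNewSketch<<<grid, block, 0, stream>>>(past_v, pv_n, new_v, nv_n, out_v);
```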
M-3.

For TensorScatter / external cache, positions …

Fix: mask …

M-4. Possible OOB read in attn_mask handling for …
The legacy path explicitly composed nonpad-bias + attn_mask into a …

M-5. Workspace blow-up: unified unfused always allocates two FP32-sized score matrices
```cpp
const size_t qk_bytes = AlignTo(SafeInt<size_t>(elems) * sizeof(float), kAlign);
const size_t softmax_bytes = AlignTo(SafeInt<size_t>(elems) * sizeof(float), kAlign);
```

For …

Fix: keep a memory-threshold fallback or fail with a clear diagnostic before allocating multi-GB scratch. At minimum, document the regression envelope. No benchmark evidence accompanied the consolidation.
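A hedged sketch of the threshold fallback suggested here; the 2 GiB limit and names are illustrative assumptions:

```cuda
#include <cstddef>
#include <cstdint>

// Refuse pathological FP32 score scratch up front instead of silently
// allocating multi-GB buffers; callers would route to an error or a
// lower-memory path when this returns false.
constexpr size_t kMaxUnfusedScratchBytesSketch = size_t{2} << 30;  // 2 GiB

inline bool UnfusedScratchFitsSketch(int64_t batch, int64_t num_heads,
                                     int64_t q_len, int64_t total_kv_len) {
  // Two FP32 score matrices: raw QK logits plus the softmax output.
  const size_t elems =
      static_cast<size_t>(batch) * num_heads * q_len * total_kv_len;
  return 2 * elems * sizeof(float) <= kMaxUnfusedScratchBytesSketch;
}
```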
M-6. Eligibility predicate for …

Same condition (…

🟡 Readability

R-1.
A reader will hunt for a …

R-2. The 6-line "ONNX spec: is_causal means upper-left…" comment block is duplicated verbatim in two MEA branches — …

R-3. …

R-4. …

R-5. New causal alignment tests are CUDA-only (…

✅ Things that look right

…
Recommended order of operations
The single must-fix before merge is C-1 — it produces silent wrong numbers on a path that the PR explicitly claims to make correct.

— Synthesized from reviews by …
force-pushed c6b9a18 to c7c327a
…r dispatch

Refactor ONNX Attention CUDA dispatch to a clean 3-tier cascade: Flash → Memory-Efficient (CUTLASS) → Unified Unfused.

- Remove RunUnfusedAttention (Legacy MHA Unfused path)
- Rename gqa_unfused_attention.cu → unfused_attention.cu as shared kernel
- Fix causal alignment: upper-left (no past) vs lower-right (with past) per ONNX spec
- Add causal_from_top_left flag to MEA (CUTLASS supports both alignments)
- Add Flash guard: block causal + cross-attention + no past (no upper-left support)
- Fix sliding-window per-batch past derivation when seqlens_k varies
- Cap ScaledCopyQkKernel grid with a grid-stride loop
- Fix v_head_size reshape in test infrastructure for H != H_v
- Fully qualify the kCutlassSafeMaskFilterValue namespace
- Remove dead mask functions from attention_mask_impl
- Add 4 C++ causal alignment tests
- Fix Python TensorScatter tests (is_causal=0 for decode per spec)
- Remove the test_attention_4d_gqa_with_past_and_present_fp16_cuda filter (now passes)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…aming

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…icate, readability fixes

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…LEMENTED

Per the ONNX spec, is_causal without past_key means upper-left alignment, which is meaningless for decode (q[0] only sees kv[0]). Reject explicitly rather than producing silent wrong results through the unfused kernel path. Add a negative test verifying the guard fires correctly.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
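A sketch of the guard's logic under hypothetical names; ORT returns a NOT_IMPLEMENTED status where this sketch throws:

```cuda
#include <sstream>
#include <stdexcept>

// Upper-left causal alignment with no past means query row 0 attends only
// kv[0]; during decode (fewer query rows than keys) that is never what the
// caller wants, so reject loudly instead of returning silent wrong values.
inline void CheckCausalDecodeSupportedSketch(bool is_causal, bool has_past,
                                             int q_len, int total_kv_len) {
  if (is_causal && !has_past && q_len < total_kv_len) {
    std::ostringstream oss;
    oss << "upper-left causal without past is not implemented for decode: "
        << "q_len=" << q_len << " < total_kv_len=" << total_kv_len;
    throw std::runtime_error(oss.str());
  }
}
```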
The 4 causal alignment tests (UpperLeft, UpperLeftSmallHead, DecodeWithPastLowerRight, SquareNoPast) were CUDA-only. The CPU EP handles them correctly — enable CPU execution.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
force-pushed c679669 to c8147e4
Motivation
Eliminate the legacy MHA Unfused path (`QkvToContext` in `attention_impl.cu`) from the ONNX standard Attention op, simplifying the CUDA dispatch to a clean 3-tier cascade.

Design
The legacy `RunUnfusedAttention` wrapper (which called the contrib op's `QkvToContext`) is deleted. The contrib MHA op is unaffected.
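A minimal sketch of the resulting cascade; the predicate names are hypothetical, and the real eligibility checks live in attention.cc:

```cuda
// Tier 1: Flash Attention; tier 2: Memory-Efficient (CUTLASS FMHA);
// tier 3: the unified unfused kernel, which is always available.
enum class AttentionKernelSketch { kFlash, kMemoryEfficient, kUnfused };

inline AttentionKernelSketch SelectKernelSketch(bool flash_supported,
                                                bool mea_supported) {
  if (flash_supported) return AttentionKernelSketch::kFlash;
  if (mea_supported) return AttentionKernelSketch::kMemoryEfficient;
  return AttentionKernelSketch::kUnfused;
}
```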
Key Behavior Changes

- Causal alignment follows the ONNX spec: upper-left with no past, lower-right with past KV present.
- `ScaledCopyQkKernel` launches are capped and use a grid-stride loop (sketched below).
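A hedged sketch of the capped-grid plus grid-stride pattern from the second bullet; the copy body is illustrative, not the actual ScaledCopyQkKernel:

```cuda
#include <cuda_runtime.h>
#include <cstdint>
#include <algorithm>

// The launch grid is bounded, and each thread strides across the tensor,
// so arbitrarily large inputs are still fully covered.
__global__ void ScaledCopySketch(const float* src, float* dst,
                                 int64_t n, float scale) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
       i < n; i += stride) {
    dst[i] = src[i] * scale;
  }
}

inline void LaunchScaledCopySketch(cudaStream_t stream, const float* src,
                                   float* dst, int64_t n, float scale) {
  if (n == 0) return;  // empty launch is a valid no-op
  constexpr int kThreads = 256;
  const int blocks = static_cast<int>(
      std::min<int64_t>((n + kThreads - 1) / kThreads, 4096));  // grid cap
  ScaledCopySketch<<<blocks, kThreads, 0, stream>>>(src, dst, n, scale);
}
```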
Testing

All existing tests pass (40 C++ attention tests, 215 Python parametrized cases), plus new coverage for causal alignment on the CPU EP and softcap ordering verification.
Closes #27880. Related: #27516, #28198.