Add Memory Efficient Attention decode support and tests for ONNX Attention #27851
Closed
titaiwangms wants to merge 8 commits into main from
Conversation
Add 5 new test classes that exercise the Memory Efficient Attention (MEA)
kernel path during the decode phase (with past KV cache):
1. TestONNXAttentionMemoryEfficientGQA.test_gqa_past_memory_efficient
- GQA + MEA + Decode: the critical missing test case
2. TestONNXAttentionPaddingMaskMemoryEfficientGQA.test_gqa_past_padding_mea
- GQA + MEA + Decode + Bool Padding Mask
3. TestONNXAttentionGQAFloatMaskDecode.test_gqa_past_float_mask_4d
- GQA + MEA + Decode + Float Mask (was a hard error before the code fix)
4. TestONNXAttentionMHAPastMEA.test_mha_past_mea
- MHA + MEA + Decode (explicit MEA path via ORT_DISABLE_FLASH_ATTENTION=1)
5. TestONNXAttentionMemoryEfficientGQABF16.test_gqa_past_memory_efficient_bf16
- BF16 + MEA + Decode
All tests follow the existing patterns: they reuse the same parity check
functions (parity_check_gqa_past, parity_check_gqa_past_with_padding,
parity_check_mha_past) and test case generators (gqa_past_test_cases,
gqa_past_padding_test_cases, mha_past_test_cases), forcing the MEA kernel
path via @patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}).
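
The five classes share the structure below; a minimal sketch, assuming the existing test module already defines `parity_check_gqa_past` and `gqa_past_test_cases` roughly as used here (the exact signatures are an assumption):

```python
import os
import unittest
from unittest.mock import patch


class TestONNXAttentionMemoryEfficientGQA(unittest.TestCase):
    @patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"})
    def test_gqa_past_memory_efficient(self):
        # With Flash Attention disabled, eligible decode-phase GQA cases fall
        # through to the Memory Efficient Attention (cutlass FMHA) kernel.
        for config in gqa_past_test_cases():   # case generator reused from the existing tests
            parity_check_gqa_past(config)      # compares ORT output against the reference


if __name__ == "__main__":
    unittest.main()
```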
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Enable Memory Efficient Attention (cutlass FMHA) to handle decode steps with past_key/past_value, previously restricted to Flash only. Changes:
- Add LaunchConcatNewToPastKV before MEA dispatch to concatenate past_key+K into present_key (and past_value+V into present_value), following the same pattern as the Flash decode path
- Remove the past_key==nullptr eligibility check from mea_eligible
- Track kv_is_bsnh separately from is_bsnh since present buffers are always BNSH after concat; pass kv_is_bsnh to LaunchUngroup and the MEA params for correct stride computation
- Set present_kv_already_populated=true after concat to skip the redundant post-attention present_key/value copy
- Enforce head_size==v_head_size for MEA decode (LaunchConcatNewToPastKV uses a single head_size parameter)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
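
For reference, a NumPy sketch of what the decode-path concat produces (illustrative layouts only, not the CUDA kernel): past KV and the new decode-step KV are joined along the sequence axis, and the present buffer ends up BNSH.

```python
import numpy as np

def concat_new_to_past_kv(past_bnsh: np.ndarray, new_bsnh: np.ndarray) -> np.ndarray:
    """past_bnsh: (B, N_kv, S_past, H); new_bsnh: (B, S_new, N_kv, H) -> (B, N_kv, S_past+S_new, H)."""
    new_bnsh = np.transpose(new_bsnh, (0, 2, 1, 3))        # BSNH -> BNSH
    return np.concatenate([past_bnsh, new_bnsh], axis=2)   # join along the sequence dim

past_k = np.zeros((2, 8, 16, 64), dtype=np.float16)        # B=2, 8 KV heads, 16 past tokens, H=64
new_k = np.zeros((2, 1, 8, 64), dtype=np.float16)          # one decode token, BSNH
present_k = concat_new_to_past_kv(past_k, new_k)
assert present_k.shape == (2, 8, 17, 64)                   # present buffer is BNSH after concat
```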
When past_key is present and head_size != v_head_size, LaunchConcatNewToPastKV cannot handle the differing sizes (it takes a single head_size parameter). Previously this hit an ORT_ENFORCE crash inside RunMemoryEfficientAttention instead of gracefully falling back to the unfused attention path. Add an eligibility guard in ComputeInternal so MEA is skipped for this configuration, allowing the unfused fallback to handle it.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
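
A sketch of the eligibility guard this commit describes (illustrative Python, not the actual ComputeInternal C++):

```python
def mea_eligible_for_decode(has_past_key: bool, head_size: int, v_head_size: int) -> bool:
    """Skip MEA when past KV is present and the Q/K and V head sizes differ."""
    if has_past_key and head_size != v_head_size:
        # LaunchConcatNewToPastKV takes a single head_size, so it cannot build a
        # present_value of a different width; let the unfused path handle this case.
        return False
    return True

assert mea_eligible_for_decode(True, 64, 64)
assert not mea_eligible_for_decode(True, 128, 64)   # previously an ORT_ENFORCE crash
```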
Medium findings from code review:
- Use a separate v_head_size for V buffer allocation and transpose instead of reusing head_size. The ORT_ENFORCE guarantees equality, but this matches Flash's defensive style and remains correct if the constraint is ever relaxed.
- Add a safety comment documenting that uninitialized present-buffer positions with bool masks are safe, because MEA's additive bias drives those positions to near-zero softmax weights.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
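
The safety argument in the second finding can be seen numerically; a NumPy illustration (not ORT code) of how a large negative additive bias zeroes out the softmax weight of masked positions:

```python
import numpy as np

scores = np.array([2.0, 1.5, 0.0, 0.0], dtype=np.float32)          # last two from padded KV slots
bias = np.array([0.0, 0.0, -10000.0, -10000.0], dtype=np.float32)  # additive mask bias
logits = scores + bias
weights = np.exp(logits - logits.max())
weights /= weights.sum()
print(weights.round(4))   # -> [0.6225 0.3775 0.     0.    ]; masked positions contribute nothing
```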
1. Rename TestONNXAttentionGQAFloatMaskDecode to TestONNXAttentionMemoryEfficientGQAFloatMaskDecode for searchability (all MEA test classes now contain 'MemoryEfficient' or 'MEA')
2. Add present_k/v verification to the float mask decode test; it now checks that the concatenated KV buffers match the reference, not just the output
3. Add a comment explaining the std=0.2 scaling (keeps fp16 numerically stable); see the sketch after this commit message
4. Add TestONNXAttentionGQA4DBNSHMEA, which exercises the 4D BNSH transpose logic through the MEA decode path (use_4d_bnsh=True)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
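
Rough numbers behind the std=0.2 comment in item 3 (an illustration under an assumed head_size of 128, not the test code): shrinking the input std shrinks the Q.K accumulations fp16 has to represent, keeping rounding error small relative to the parity tolerance.

```python
import numpy as np

rng = np.random.default_rng(0)
head_size = 128
for std in (1.0, 0.2):
    q = rng.standard_normal(head_size) * std
    k = rng.standard_normal(head_size) * std
    print(f"std={std}: |q.k| = {abs(float(q @ k)):.2f}")
# In expectation the std=0.2 dot products are 25x smaller (0.2 * 0.2 = 0.04),
# so fp16 accumulation error stays well inside the parity-check tolerance.
```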
When bool masks produce variable per-batch past_seq_lens, the concat kernel leaves some positions unwritten. MEA reads all positions (unlike Flash, which bounds reads via seqlens_k). Zero the present buffers first so unwritten positions contain 0.0 instead of potentially NaN-containing uninitialized memory.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
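
A NumPy sketch of the failure mode and fix (illustrative shapes, not the kernel): with variable per-batch past lengths, the concat writes only past_len[b]+1 rows per batch, so the present buffer is zero-filled up front.

```python
import numpy as np

batch, num_kv_heads, max_seq, head = 2, 4, 8, 16
past_lens = [3, 6]                                   # variable per-batch past lengths
present_k = np.zeros((batch, num_kv_heads, max_seq, head), dtype=np.float16)  # zero first
for b, past_len in enumerate(past_lens):
    past = np.ones((num_kv_heads, past_len, head), dtype=np.float16)
    new = np.ones((num_kv_heads, 1, head), dtype=np.float16)   # one decode token
    present_k[b, :, : past_len + 1] = np.concatenate([past, new], axis=1)
# Rows beyond past_len + 1 stay 0.0 instead of holding stale (possibly NaN) memory,
# which matters because MEA reads every position of the present buffer.
```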
Add TestONNXAttentionMHAAsymmetricHeadSize to verify that MEA gracefully falls back to unfused attention when head_size != v_head_size with past_key present (decode phase). Without the eligibility guard in ComputeInternal, this would crash with an ORT_ENFORCE in LaunchConcatNewToPastKV.

To support this test, add a v_head_size field to AttentionConfig (defaults to 0 = same as head_size) and propagate it through the ONNX graph builder and io_binding helpers for V-related shapes (V input, past_value, present_value, Y output). All existing tests are unaffected since they don't set v_head_size.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
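
A sketch of the AttentionConfig extension (field names other than head_size and v_head_size are placeholders; the real config carries more fields):

```python
from dataclasses import dataclass

@dataclass
class AttentionConfig:
    batch_size: int
    num_heads: int
    sequence_length: int
    head_size: int
    v_head_size: int = 0   # 0 means "same as head_size", so existing call sites are unchanged

    def effective_v_head_size(self) -> int:
        # V input, past_value, present_value, and Y shapes use this value.
        return self.v_head_size or self.head_size

cfg = AttentionConfig(batch_size=2, num_heads=8, sequence_length=128, head_size=128, v_head_size=64)
assert cfg.effective_v_head_size() == 64
assert AttentionConfig(2, 8, 128, 128).effective_v_head_size() == 128
```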
Replace `config.v_head_size if config.v_head_size else config.head_size` with `config.v_head_size or config.head_size` per ruff FURB110.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
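
Before and after, as a self-contained snippet (the config object is stubbed here for illustration):

```python
from types import SimpleNamespace

config = SimpleNamespace(v_head_size=0, head_size=128)

# Before (flagged by ruff FURB110):
v_head_size = config.v_head_size if config.v_head_size else config.head_size
# After (equivalent, since v_head_size == 0 is falsy and means "same as head_size"):
v_head_size = config.v_head_size or config.head_size
assert v_head_size == 128
```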
This was referenced Mar 27, 2026
titaiwangms added a commit that referenced this pull request on Apr 4, 2026
…r fully-masked batches (#27831)

Description:

### Summary
Fixes three issues in the CUDA ONNX Attention operator and improves spec compliance:
1. min_bias_align crash on SM<80: The alignment check for Memory Efficient Attention (MEA) bias used 4*sizeof(T) (bytes), but the check is against element counts. Fixed to 4 elements, matching CUTLASS kMinimumAlignment. This prevented valid MEA dispatch on SM<80.
2. MEA NaN for fully-masked batches: When nonpad_kv_seqlen=0, CUTLASS MEA computes 1/s_prime where s_prime=0, producing NaN. Added a ZeroOutputForFullyMaskedBatches kernel (MEA path only) to zero the output for these batches. Uses int64_t for the element count to prevent overflow at large context lengths.
3. Flash rejects attn_mask for spec compliance: Flash Attention's paged KV cache produces a spec-divergent present_key/present_value layout when used with attn_mask + past_key. Flash now requires attn_mask == nullptr; cases with a bool mask + past_key fall to the unfused runner, which handles them spec-correctly. Removed ~137 lines of dead code (ConvertMaskToSeqlensKernel, LaunchConvertMaskToFlashSeqlensK) no longer needed after this change.

### Known limitation
- GQA + bool attn_mask + past_key currently has no runner (Flash rejects it, unfused doesn't support GQA, MEA is blocked by past_key). Tracked via TODO; PR #27851 (MEA with past_key support) will close this gap.

### Related
- Issue #27885: Flash Attention bool attn_mask semantic divergence (root cause documented)
- PR #27851: MEA with past_key support (will close the GQA gap)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
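
A NumPy illustration of issue 2 above (reference math, not the CUDA code): when every KV position of a batch is masked out, the softmax normalizer is zero and the division produces NaN, which the new kernel avoids by zeroing those output rows.

```python
import numpy as np

logits = np.full((1, 4), -np.inf)                        # fully masked batch row
weights = np.exp(logits)                                 # all zeros
with np.errstate(invalid="ignore"):
    out = weights / weights.sum(axis=-1, keepdims=True)  # 0/0 -> NaN, the bug
print(out)                                               # [[nan nan nan nan]]

fixed = np.where(weights.sum(axis=-1, keepdims=True) == 0, 0.0, out)
print(fixed)                                             # [[0. 0. 0. 0.]], as after the fix
```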
titaiwangms (Contributor, Author) commented:
Replaced by #27992
This pull request introduces support for asymmetric Q/V head sizes in the CUDA Memory Efficient Attention (MEA) kernel and updates the associated Python test utilities to handle cases where the value head size (`v_head_size`) differs from the query/key head size (`head_size`). The changes ensure correct handling of tensor shapes and memory layouts throughout the MEA code path, especially during decoding with past key/value caches.

CUDA MEA kernel improvements:
- `head_size` and `v_head_size` are handled separately throughout the MEA path, including in the decode (past key/value) case. This involves updating buffer allocations, transpositions, and kernel launches to use the correct `v_head_size` where appropriate. [1] [2]
- The decode (past key/value) code paths are updated for handling `head_size != v_head_size` and ensuring the correct memory layout (`BSNH` vs `BNSH`). [1] [2] [3] [4] [5] [6] [7]
- MEA falls back to unfused attention when `head_size != v_head_size` during decode, as required by the concat kernel.

Python test utilities and shape handling:
- The `AttentionConfig` class and all related test code now accept and propagate an explicit `v_head_size`, defaulting to `head_size` for backward compatibility. All input/output tensor shapes, cache shapes, and present key/value shapes now correctly reflect `v_head_size` where applicable. [1] [2] [3] [4] [5] [6] [7] [8] [9]

Documentation and code clarity:
These changes collectively ensure robust support for models that use different head sizes for query/key and value projections, both in the CUDA kernel and in test coverage.
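
A reference NumPy sketch of the asymmetric shapes (illustrative only, not ORT code): Q and K share `head_size`, while V and the output Y use `v_head_size`.

```python
import numpy as np

batch, num_heads, q_len, kv_len = 2, 4, 1, 17
head_size, v_head_size = 128, 64

q = np.random.rand(batch, num_heads, q_len, head_size).astype(np.float32)
k = np.random.rand(batch, num_heads, kv_len, head_size).astype(np.float32)
v = np.random.rand(batch, num_heads, kv_len, v_head_size).astype(np.float32)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)      # (B, N, q_len, kv_len)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
y = weights @ v                                                 # (B, N, q_len, v_head_size)
assert y.shape == (batch, num_heads, q_len, v_head_size)
```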