Add Memory Efficient Attention decode support and tests for ONNX Attention #27851

Closed
titaiwangms wants to merge 8 commits into main from feature/mea-decode-support

Conversation

@titaiwangms
Contributor

This pull request introduces support for asymmetric Q/V head sizes in the CUDA Memory Efficient Attention (MEA) kernel and updates the associated Python test utilities to handle cases where the value head size (v_head_size) differs from the query/key head size (head_size). The changes ensure correct handling of tensor shapes and memory layouts throughout the MEA code path, especially during decoding with past key/value caches.

CUDA MEA kernel improvements:

  • Added support for asymmetric Q/V head sizes by ensuring that head_size and v_head_size are handled separately throughout the MEA path, including in the decode (past key/value) case. This involves updating buffer allocations, transpositions, and kernel launches to use the correct v_head_size where appropriate. [1] [2]
  • Modified the decode path to concatenate past and new K/V tensors into present buffers, correctly handling the case where head_size != v_head_size and ensuring the correct memory layout (BSNH vs BNSH). [1] [2] [3] [4] [5] [6] [7]
  • Updated MEA eligibility logic to fall back to the unfused path if head_size != v_head_size during decode, as required by the concat kernel.

Python test utilities and shape handling:

  • Updated the AttentionConfig class and all related test code to accept and propagate an explicit v_head_size, defaulting to head_size for backward compatibility. All input/output tensor shapes, cache shapes, and present key/value shapes now correctly reflect v_head_size where applicable. [1] [2] [3] [4] [5] [6] [7] [8] [9]
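
For orientation, here is a minimal, self-contained sketch (not the actual onnxruntime test utility) of how a config with an optional v_head_size can derive the V-dependent shapes. The field names and the BSNH/BNSH choices below are illustrative assumptions; only head_size, v_head_size, and the "0 means same as head_size" default come from this PR.

```python
# Illustrative sketch only -- not the actual onnxruntime test helper.
# It shows the v_head_size-or-head_size defaulting and the shapes that
# depend on it (V input, past_value, present_value, output).
from dataclasses import dataclass


@dataclass
class AttentionConfigSketch:
    batch_size: int
    num_heads: int           # query heads
    kv_num_heads: int        # key/value heads (== num_heads for MHA)
    sequence_length: int     # new tokens (1 for a decode step)
    past_sequence_length: int
    head_size: int           # Q/K head size
    v_head_size: int = 0     # 0 means "same as head_size"

    @property
    def effective_v_head_size(self) -> int:
        # Pattern noted by ruff FURB110: `x or y` instead of `x if x else y`.
        return self.v_head_size or self.head_size

    def shapes(self) -> dict:
        b, s = self.batch_size, self.sequence_length
        total = self.past_sequence_length + s
        h, hv = self.head_size, self.effective_v_head_size
        return {
            "Q": (b, s, self.num_heads * h),
            "K": (b, s, self.kv_num_heads * h),
            "V": (b, s, self.kv_num_heads * hv),
            "present_key (BNSH)": (b, self.kv_num_heads, total, h),
            "present_value (BNSH)": (b, self.kv_num_heads, total, hv),
            "Y": (b, s, self.num_heads * hv),
        }


print(AttentionConfigSketch(2, 8, 8, 1, 16, 64, v_head_size=96).shapes())
```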

Documentation and code clarity:

  • Improved comments in the MEA dispatch path to clarify handling of decode, prompt, and external cache scenarios, and updated notes on eligibility and kernel requirements.

These changes collectively ensure robust support for models that use different head sizes for query/key and value projections, both in the CUDA kernel and in test coverage.

titaiwangms and others added 8 commits March 25, 2026 23:07
Add 5 new test classes that exercise the Memory Efficient Attention (MEA)
kernel path during the decode phase (with past KV cache):

1. TestONNXAttentionMemoryEfficientGQA.test_gqa_past_memory_efficient
   - GQA + MEA + Decode: the critical missing test case
2. TestONNXAttentionPaddingMaskMemoryEfficientGQA.test_gqa_past_padding_mea
   - GQA + MEA + Decode + Bool Padding Mask
3. TestONNXAttentionGQAFloatMaskDecode.test_gqa_past_float_mask_4d
   - GQA + MEA + Decode + Float Mask (was a HARD ERROR before code fix)
4. TestONNXAttentionMHAPastMEA.test_mha_past_mea
   - MHA + MEA + Decode (explicit MEA path via ORT_DISABLE_FLASH_ATTENTION=1)
5. TestONNXAttentionMemoryEfficientGQABF16.test_gqa_past_memory_efficient_bf16
   - BF16 + MEA + Decode

All tests follow the existing patterns: they reuse the same parity check
functions (parity_check_gqa_past, parity_check_gqa_past_with_padding,
parity_check_mha_past) and test case generators (gqa_past_test_cases,
gqa_past_padding_test_cases, mha_past_test_cases), forcing the MEA kernel
path via @patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}).
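
As a rough, self-contained sketch of that pattern (the class name, the config dict, and the stubbed parity helper below are placeholders; only the ORT_DISABLE_FLASH_ATTENTION environment variable and the @patch.dict structure come from the commit message):

```python
# Hypothetical sketch of the test pattern described above; the real tests
# reuse parity_check_gqa_past and the existing test-case generators.
import os
import unittest
from unittest.mock import patch


def parity_check_gqa_past(config):
    """Placeholder for the real parity helper from the attention test file."""
    # The real helper builds an ONNX Attention model, runs it on CUDA with a
    # past KV cache, and compares against a reference implementation.
    pass


class TestGqaPastMeaSketch(unittest.TestCase):
    @patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"})
    def test_gqa_past_memory_efficient(self):
        # Flash is disabled for the duration of this test, so an eligible
        # GQA decode case should be routed to Memory Efficient Attention.
        self.assertEqual(os.environ["ORT_DISABLE_FLASH_ATTENTION"], "1")
        parity_check_gqa_past(config={"sequence_length": 1, "past_sequence_length": 16})


if __name__ == "__main__":
    unittest.main()
```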

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Enable Memory Efficient Attention (cutlass FMHA) to handle decode
steps with past_key/past_value, previously restricted to Flash only.

Changes:
- Add LaunchConcatNewToPastKV before MEA dispatch to concatenate
  past_key+K into present_key (and past_value+V into present_value)
  following the same pattern as the Flash decode path (see the NumPy
  sketch after this list)
- Remove past_key==nullptr eligibility check from mea_eligible
- Track kv_is_bsnh separately from is_bsnh since present buffers are
  always BNSH after concat; pass kv_is_bsnh to LaunchUngroup and MEA
  params for correct stride computation
- Set present_kv_already_populated=true after concat to skip redundant
  post-attention present_key/value copy
- Enforce head_size==v_head_size for MEA decode (LaunchConcatNewToPastKV
  uses a single head_size parameter)
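
The following NumPy snippet is a rough host-side analogue of what the concat step does logically; it is not the CUDA kernel, and the shapes and names are illustrative.

```python
# Rough NumPy analogue of the decode-time concat described above; the real
# work is done by the CUDA kernel LaunchConcatNewToPastKV.
import numpy as np

batch, kv_heads, head_size = 2, 4, 64
past_len, new_len = 16, 1

# past K/V caches are stored BNSH: (batch, num_kv_heads, past_len, head_size)
past_key = np.random.randn(batch, kv_heads, past_len, head_size).astype(np.float16)

# the new K for this decode step arrives BSNH: (batch, new_len, num_kv_heads, head_size)
new_key_bsnh = np.random.randn(batch, new_len, kv_heads, head_size).astype(np.float16)

# bring the new tokens into BNSH, then append along the sequence axis (axis=2)
new_key_bnsh = new_key_bsnh.transpose(0, 2, 1, 3)
present_key = np.concatenate([past_key, new_key_bnsh], axis=2)

# present buffers are therefore always BNSH after the concat, which is why
# kv_is_bsnh is tracked separately from is_bsnh for stride computation.
assert present_key.shape == (batch, kv_heads, past_len + new_len, head_size)
```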

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
When past_key is present and head_size != v_head_size,
LaunchConcatNewToPastKV cannot handle the differing sizes (single
head_size parameter). Previously this hit an ORT_ENFORCE crash inside
RunMemoryEfficientAttention instead of gracefully falling back to the
unfused attention path.

Add eligibility guard in ComputeInternal so MEA is skipped for this
configuration, allowing the unfused fallback to handle it.
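
Restated as a hypothetical Python predicate (the real guard is C++ in ComputeInternal and sits alongside other eligibility conditions not shown here):

```python
# Hypothetical restatement of the eligibility guard; only the condition it
# expresses comes from the commit message.
def mea_eligible_for_decode(has_past_key: bool, head_size: int, v_head_size: int) -> bool:
    if has_past_key and head_size != v_head_size:
        # LaunchConcatNewToPastKV takes a single head_size, so it cannot build
        # present buffers with differing K and V head sizes; fall back to the
        # unfused attention path instead of crashing on an ORT_ENFORCE.
        return False
    return True


assert mea_eligible_for_decode(True, 64, 64)
assert not mea_eligible_for_decode(True, 64, 96)
```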

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Medium findings from code review:
- Use separate v_head_size for V buffer allocation and transpose instead
  of reusing head_size. The ORT_ENFORCE guarantees equality, but this
  matches Flash's defensive style and is correct if the constraint is
  ever relaxed.
- Add safety comment documenting that uninitialized present-buffer
  positions with bool masks are safe because MEA's additive bias drives
  those positions to near-zero softmax weights.
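
A small NumPy illustration of that safety comment: a sufficiently negative additive bias pushes a masked position's softmax weight to essentially zero, so a finite garbage score at an unwritten position cannot influence the output (NaN garbage is the separate concern addressed by the later zero-initialization commit). The values below are illustrative.

```python
# Why a masked-out (possibly uninitialized) position is harmless under an
# additive attention bias, as long as its score is finite.
import numpy as np

# Scores for one query row; the third key position was never written by the
# concat kernel, so its score is whatever the dot product with stale data gave.
scores = np.array([2.0, 1.5, 57.3, 3.0], dtype=np.float32)    # 57.3 stands in for garbage
bias = np.array([0.0, 0.0, -10000.0, 0.0], dtype=np.float32)  # bool mask -> additive bias

z = scores + bias
w = np.exp(z - z.max()) / np.exp(z - z.max()).sum()
assert w[2] < 1e-6   # the masked position carries essentially no weight
```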

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1. Rename TestONNXAttentionGQAFloatMaskDecode to
   TestONNXAttentionMemoryEfficientGQAFloatMaskDecode for searchability
   (all MEA test classes now contain 'MemoryEfficient' or 'MEA')
2. Add present_k/v verification to float mask decode test — now checks
   that concatenated KV buffers match reference, not just output
3. Add comment explaining std=0.2 scaling (keeps fp16 numerically stable;
   see the quick numerical check after this list)
4. Add TestONNXAttentionGQA4DBNSHMEA — exercises 4D BNSH transpose logic
   through the MEA decode path (use_4d_bnsh=True)
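
The quick numerical check below (sizes assumed, not taken from the tests) shows why std=0.2 helps: Q·K dot products scale roughly like std² · sqrt(head_size), so unit-variance inputs produce pre-softmax scores with a spread around 11 for head_size=128, while std=0.2 keeps the spread under 0.5, comfortably inside fp16's well-conditioned range.

```python
# Quick numerical check (assumed sizes) of why std=0.2 keeps fp16 attention
# scores tame: Q.K dot products scale like std**2 * sqrt(head_size).
import numpy as np

rng = np.random.default_rng(0)
head_size, n = 128, 10000

for std in (1.0, 0.2):
    q = rng.normal(0.0, std, size=(n, head_size))
    k = rng.normal(0.0, std, size=(n, head_size))
    scores = (q * k).sum(axis=1)
    print(f"std={std}: score std ~ {scores.std():.2f}")
# std=1.0 gives scores with a spread around 11; std=0.2 keeps them around
# 0.45, which stays well-conditioned after fp16 rounding and softmax.
```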

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
When bool masks produce variable per-batch past_seq_lens, the concat
kernel leaves some positions unwritten. MEA reads all positions
(unlike Flash which bounds reads via seqlens_k). Zero the present
buffers first so unwritten positions contain 0.0 instead of
potentially NaN-containing uninitialized memory.
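
In NumPy terms (illustrative shapes; the real buffers are CUDA allocations written by the concat kernel), the fix amounts to:

```python
# Zero-then-partial-write pattern: unwritten tail positions stay 0.0 rather
# than whatever stale device memory held.
import numpy as np

batch, kv_heads, max_total_len, head_size = 2, 2, 8, 4
past_seq_lens = [3, 6]          # variable per-batch lengths from a bool mask

# zero the present buffer up front so unwritten positions are 0.0, not garbage
present_key = np.zeros((batch, kv_heads, max_total_len, head_size), dtype=np.float16)

for b, valid in enumerate(past_seq_lens):
    # only the first `valid` past positions plus the one new decode token get
    # real data; the rest of the sequence axis stays at 0.0
    present_key[b, :, : valid + 1] = np.random.randn(kv_heads, valid + 1, head_size)

assert not np.isnan(present_key).any()   # MEA can safely read every position
```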

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (16a065d8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add TestONNXAttentionMHAAsymmetricHeadSize to verify that MEA gracefully
falls back to unfused attention when head_size != v_head_size with past_key
present (decode phase). Without the eligibility guard in ComputeInternal,
this would crash with ORT_ENFORCE in LaunchConcatNewToPastKV.

To support this test, add v_head_size field to AttentionConfig (defaults
to 0 = same as head_size) and propagate it through the ONNX graph builder
and io_binding helpers for V-related shapes (V input, past_value,
present_value, Y output). All existing tests are unaffected since they
don't set v_head_size.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Replace `config.v_head_size if config.v_head_size else config.head_size`
with `config.v_head_size or config.head_size` per ruff FURB110.
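
The two spellings are equivalent here because v_head_size is an int whose only falsy value is 0, which is exactly the "same as head_size" sentinel:

```python
head_size, v_head_size = 64, 0

old_style = v_head_size if v_head_size else head_size   # flagged by ruff FURB110
new_style = v_head_size or head_size                    # preferred spelling
assert old_style == new_style == 64
```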

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (b0ebe545) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms changed the title from "Add Memory Efficient Attention decode support and tests for ONNX" to "Add Memory Efficient Attention decode support and tests for ONNX Attention" on Mar 26, 2026
titaiwangms added a commit that referenced this pull request Apr 4, 2026
…r fully-masked batches (#27831)

Description:

### Summary
Fixes three issues in the CUDA ONNX Attention operator and improves spec
compliance:

1. min_bias_align crash on SM<80: The alignment check for Memory
Efficient Attention (MEA) bias used 4*sizeof(T) (bytes), but the check
is against element counts. Fixed to 4 elements, matching CUTLASS
kMinimumAlignment. This prevented valid MEA dispatch on SM<80.

2. MEA NaN for fully-masked batches: When nonpad_kv_seqlen=0, CUTLASS
MEA computes 1/s_prime where s_prime=0, producing NaN. Added
ZeroOutputForFullyMaskedBatches kernel (MEA path only) to zero output
for these batches. Uses int64_t for element count to prevent overflow at
large context lengths. (A small NumPy illustration of the NaN mechanism
follows this list.)

3. Flash rejects attn_mask for spec compliance: Flash Attention's paged
KV cache produces spec-divergent present_key/present_value layout when
used with attn_mask + past_key. Flash now requires attn_mask == nullptr
— cases with bool mask + past_key fall to the unfused runner which
handles them spec-correctly. Removed ~137 lines of dead code
(ConvertMaskToSeqlensKernel, LaunchConvertMaskToFlashSeqlensK) no longer
needed after this change.
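
As a small NumPy illustration of the failure mode in item 2 (not the CUTLASS internals; shapes and values are illustrative): when every key position is masked, the softmax normalizer s_prime is 0 and the 1/s_prime rescaling turns the whole row into NaN, which is why the new kernel zeroes the output for those batches instead.

```python
# NumPy illustration of the fully-masked softmax failure mode from item 2.
import numpy as np

scores = np.full(8, -np.inf, dtype=np.float32)   # every KV position masked out
p = np.exp(scores)                                # all zeros
s_prime = p.sum()                                 # 0.0
with np.errstate(divide="ignore", invalid="ignore"):
    weights = p / s_prime                         # 0/0 -> NaN for every position
assert np.isnan(weights).all()
```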

### Known limitation
- GQA + bool attn_mask + past_key currently has no runner (Flash
rejected, unfused doesn't support GQA, MEA blocked by past_key). Tracked
via TODO — PR #27851 (MEA with past_key support) will close this gap.

### Related
- Issue #27885: Flash Attention bool attn_mask semantic divergence (root
cause documented)
- PR #27851: MEA with past_key support (will close GQA gap)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms
Contributor Author

Replaced by #27992
