
Eliminate Legacy MHA Unfused path from ONNX Attention; unify on 3-tier dispatch with causal alignment fix #27992

Merged
titaiwangms merged 6 commits into microsoft:main from titaiwangms:feature/mea-decode-support-v2
May 4, 2026

Conversation


@titaiwangms titaiwangms commented Apr 6, 2026

Motivation

Eliminate the legacy MHA Unfused path (QkvToContext in attention_impl.cu) from the ONNX standard Attention op, simplifying the CUDA dispatch to a clean 3-tier cascade.

Design

Flash Attention → Memory-Efficient Attention (MEA) → Unified Unfused Attention
  • Flash: Handles fp16/bf16 with head_size ≤ 256, no explicit attn_mask. Fastest path.
  • MEA (CUTLASS): Handles cases Flash cannot (explicit masks, softcap+mask combos). Requires head_size % 8 == 0.
  • Unified Unfused: Fallback for everything else — fp32, small heads, H≠H_v, output_qk. Handles both MHA and GQA via FP32 QK accumulation.

The legacy RunUnfusedAttention wrapper (which called contrib ops QkvToContext) is deleted. The contrib MHA op is unaffected.
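A minimal host-side sketch of the cascade above, assuming the eligibility rules listed for each tier. The struct and function names (`DispatchInputs`, `SelectAttentionTier`) are illustrative, not the actual helpers in attention.cc:

```cpp
// Illustrative sketch of the 3-tier dispatch described above.
enum class AttentionTier { kFlash, kMemoryEfficient, kUnifiedUnfused };

struct DispatchInputs {
  bool is_fp16_or_bf16;        // Flash handles only half-precision inputs
  int head_size;               // H
  int v_head_size;             // H_v
  bool has_explicit_attn_mask;
  bool needs_output_qk;
};

inline AttentionTier SelectAttentionTier(const DispatchInputs& in) {
  // Tier 1: Flash Attention (fastest): fp16/bf16, head_size <= 256, no explicit mask.
  if (in.is_fp16_or_bf16 && in.head_size <= 256 && in.head_size == in.v_head_size &&
      !in.has_explicit_attn_mask && !in.needs_output_qk) {
    return AttentionTier::kFlash;
  }
  // Tier 2: Memory-Efficient Attention (CUTLASS): handles explicit masks and
  // softcap+mask combinations, but requires head_size % 8 == 0.
  if (in.head_size % 8 == 0 && in.head_size == in.v_head_size && !in.needs_output_qk) {
    return AttentionTier::kMemoryEfficient;
  }
  // Tier 3: Unified unfused fallback: fp32, small heads, H != H_v, output_qk.
  return AttentionTier::kUnifiedUnfused;
}
```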

Key Behavior Changes

  • Unified unfused kernel replaces separate GQA-only and MHA-only unfused paths
  • Causal alignment: lower-right when past_key is present, upper-left otherwise, per the ONNX spec (see the sketch after this list)
  • H≠H_v + past KV now supported (separate K/V concat calls)
  • output_qk (mode 0) supported in unified kernel via ScaledCopyQkKernel
  • 29 ONNX backend test filters removed — tests now pass natively
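A minimal reference loop for the two alignments, assuming the past length equals total_kv minus the query length when past_key is present. `BuildCausalMask` is a hypothetical helper for illustration; the actual kernel applies the constraint to the score matrix in place:

```cpp
// Reference construction of the two causal alignments named above.
//   upper-left:  query row i may attend to kv columns j <= i          (no past_key)
//   lower-right: query row i may attend to kv columns j <= i + past   (past_key present)
#include <vector>

std::vector<std::vector<bool>> BuildCausalMask(int s_q, int total_kv, bool lower_right) {
  const int past = lower_right ? (total_kv - s_q) : 0;
  std::vector<std::vector<bool>> allowed(s_q, std::vector<bool>(total_kv, false));
  for (int i = 0; i < s_q; ++i)
    for (int j = 0; j < total_kv; ++j)
      allowed[i][j] = (j <= i + past);
  return allowed;
}
// Example: s_q = 3, total_kv = 10 -> upper-left lets q[0] see kv[0] only,
// lower-right lets q[0] see kv[0..7] and q[2] see all ten positions.
```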

Testing

All existing tests pass (40 C++ attention tests, 215 Python parametrized cases) plus new coverage for causal alignment on CPU EP and softcap ordering verification.

Closes #27880. Related: #27516, #28198.


Copilot AI left a comment


Pull request overview

This PR extends ONNX Runtime’s CUDA Attention operator to support softcap in the unfused CUDA path and adds Memory Efficient Attention (CUTLASS FMHA) decode support (past/present KV cache), along with expanded ONNX/Python/C++ tests and updated backend test filters.

Changes:

  • Add softcap plumbing (AttentionParameters.softcap) and apply softcap to unfused attention logits via a CUDA kernel before softmax (see the sketch after this list).
  • Implement MEA decode by concatenating past+new KV into the present buffer (via LaunchConcatNewToPastKV) and updating kernel selection/verbosity logs.
  • Add/adjust Python and C++ tests for MEA decode, unfused softcap, bool/float masks, and asymmetric head-size fallback; update ONNX backend filters accordingly.
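A rough sketch of the softcap math applied before softmax, using the standard logit-softcapping formula. `ApplySoftcapReference` is a hypothetical scalar reference; the actual ApplySoftcap in attention_impl.cu is a CUDA kernel whose layout and launch configuration are not reproduced here:

```cpp
// Host-side reference of logit softcapping: logits <- softcap * tanh(logits / softcap).
#include <cmath>
#include <vector>

void ApplySoftcapReference(std::vector<float>& logits, float softcap) {
  if (softcap == 0.0f) return;  // assumption: softcap == 0 means "disabled"
  for (float& x : logits) x = softcap * std::tanh(x / softcap);
}
```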

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/providers/cuda/llm/attention.cc | Adds MEA decode path (concat KV into present), propagates softcap to unfused, and adds verbose kernel-selection logging. |
| onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | Implements the ApplySoftcap CUDA kernel and applies it in unfused attention before softmax. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds softcap to AttentionParameters. |
| onnxruntime/test/providers/cpu/llm/attention_op_test.cc | Enables CUDA softcap tests and adds a CUDA MEA decode regression test (forced via env var). |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | Adds v_head_size support to graph IO shapes/bindings; adds an MEA alignment helper for decode+mask tests. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py | Adds MEA decode tests (fp16/fp32, bool/float masks), unfused softcap tests, and an asymmetric head-size fallback regression test. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | Re-enables/extends GQA MEA decode tests (including bf16) and adjusts padding-mask cases for MEA alignment. |
| onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc | Updates exclusions/comments for attention backend tests based on new softcap/decode behavior. |

titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request Apr 6, 2026
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with
   scratch buffer allocation when outputs are nullptr. MEA decode now
   works even when present outputs are not requested. Use ORT_RETURN_IF_NOT
   for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with
   early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title from "Add Memory Efficient Attention decode support and tests for ONNX" to "ONNX Attention CUDA: Add MEA decode support and unfused softcap" on Apr 6, 2026

@titaiwangms titaiwangms left a comment


Probably not, but check that we are not breaking the graph at any point. For example, #27484

Also, are the nonpad_kv_seqlens paths totally unrelated to the gap table? I don't see it mentioned at all.


Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.


@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: Add MEA decode support and unfused softcap" to "ONNX Attention CUDA: MEA decode, unfused softcap, and spec-correct softcap ordering" on Apr 7, 2026
@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: MEA decode, unfused softcap, and spec-correct softcap ordering" to "ONNX Attention CUDA: MEA decode, unfused softcap, spec-correct ordering, and NaN fix" on Apr 7, 2026
@titaiwangms titaiwangms requested a review from Copilot April 7, 2026 23:25

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.



Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.



@tianleiwu tianleiwu left a comment


Thanks for the updates here. I re-checked the current head and the earlier MEA optional-output, softcap zero-size, grid-stride, and mask-test concerns look addressed. I found one remaining correctness issue in the new unfused softcap path: kQK output is copied after the logits have already been softcapped, and in the softcap+bias branch it is not copied at all.

@justinchuby

Will this fix #28196 and #28195?

@titaiwangms titaiwangms requested a review from tianleiwu April 23, 2026 17:44
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch 2 times, most recently from c74bbaa to e5f74e9 on April 23, 2026 at 21:17
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request Apr 23, 2026
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with
   scratch buffer allocation when outputs are nullptr. MEA decode now
   works even when present outputs are not requested. Use ORT_RETURN_IF_NOT
   for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with
   early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch from e5f74e9 to 5afc98e on April 23, 2026 at 21:20
tianleiwu previously approved these changes Apr 23, 2026

@tianleiwu tianleiwu left a comment


APPROVE

The implementation is sound, addresses the prior output_qk ordering concern (resolved that thread), and provides comprehensive test coverage across MEA decode, unfused softcap, spec-correct mask→softcap ordering, and the NaN fix. No high-severity issues found.

Highlights

  • output_qk copy correctly moved before bias/softcap mutations — prior CHANGES_REQUESTED concern resolved.
  • MEA decode path has well-managed scratch buffer lifetimes; kv_is_bsnh tracking correctly propagates through GQA expansion.
  • mask_filter_value cap (-1e+30f) is correctly scoped to MEA only and the math is verified.
  • Test coverage is thorough: C++ and Python, MEA/unfused, prompt/decode, fp16/bf16/fp32, various mask dims.

Suggestions (non-blocking)

  1. Causal mask ordering with softcap: In the has_softcap && has_bias branch, ComputeSoftmax applies is_unidirectional causal masking after softcap. Strict ONNX spec folds causal into the mask before softcap. This matches the CPU reference and tests pass, but a TODO comment would help track future spec alignment (a scalar sketch of the spec ordering follows this list).
  2. Defensive assertion: The has_softcap && has_bias branch bypasses use_raw_attention_mask handling. Safe today (ONNX domain converts all masks to data.attention_bias before unfused), but adding assert(!use_raw_attention_mask) guards against future contrib softcap use where mask_index could coexist.
  3. Bool mask test filters: test_attention_4d_attn_mask_bool_cuda and _4d_cuda are still excluded with "may work now" TODOs. Consider removing or filing a follow-up issue.
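For reference, a scalar sketch of the ordering discussed in suggestion 1: scaled QK, then softcap, then the additive mask (with the causal constraint already folded into the mask), then softmax. `AttentionRowReference` is illustrative only and is not the ComputeSoftmax code path:

```cpp
// Scalar reference of the spec ordering on one row of scaled QK scores.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> AttentionRowReference(std::vector<float> logits,       // one scaled QK row
                                         const std::vector<float>& mask,  // -inf where disallowed
                                         float softcap) {
  for (std::size_t j = 0; j < logits.size(); ++j) {
    if (softcap != 0.0f) logits[j] = softcap * std::tanh(logits[j] / softcap);
    logits[j] += mask[j];  // causal + padding contribution applied after softcap
  }
  const float max_v = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float& v : logits) { v = std::exp(v - max_v); sum += v; }
  for (float& v : logits) v /= sum;
  return logits;
}
```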

@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch 4 times, most recently from a8f10de to 0a85503 on April 30, 2026 at 21:42
@titaiwangms titaiwangms requested a review from Copilot April 30, 2026 21:52

Copilot AI left a comment


Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 3 comments.


@titaiwangms

Multi-agent review summary

I ran this PR through four independent reviewers (gpt-5.3-codex, gpt-5.5, claude-sonnet-4.6, claude-opus-4.7-high). There is strong consensus on one Critical bug plus several Major correctness/structural issues. Findings consolidated below; nits omitted unless they hide a bug.


🔴 Critical

C-1. Unified unfused kernel silently produces lower-right when caller asks for upper-left under TensorScatter / external-cache + causal

  • File: onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu:120
  • Caller: onnxruntime/core/providers/cuda/llm/attention.cc (sets p.past_kv_length = parameters.past_sequence_length)
const int past = (seqlens_k != nullptr) ? (kv_end - q_sequence_length) : past_kv_length;

When nonpad_kv_seqlen != nullptr, the kernel ignores the new past_kv_length field that dispatch carefully sets to 0 for the no-past TensorScatter case. It always derives past = kv_end − S_q, i.e. lower-right.

Concrete failure (matrix row #4):

  • is_causal=1, past_key=nullptr, nonpad_kv_seqlen=[10], S_q=3, kv_total=10
  • Spec/MEA path: q[0] attends to kv[0] only (upper-left)
  • Unfused kernel: past = 10 − 3 = 7, so q[0] attends to kv[0..7] (lower-right) — different numerical output for the same op + inputs depending on which dispatch tier wins

Reachable today via: output_qk requested, fp32 GQA, head_size > 1024, or any case that forces the unified unfused tier.

Three of four reviewers flagged this independently. The PR's response (deleting 20 TensorScatter is_causal=1 Python tests as "spec-invalid") hides the bug rather than fixing it — see M-1.

Suggested fix: use past_kv_length for the causal cutoff unconditionally, while still using per-batch kv_end only as the upper bound. Or add an explicit causal_alignment enum to UnfusedAttentionParams mirroring the MEA causal_from_top_left plumbing.
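A minimal sketch of that suggested fix, written as a hypothetical predicate that mirrors the variable names in the snippet above rather than the actual unfused_attention.cu code:

```cpp
// Alignment comes from past_kv_length (0 => upper-left, past length => lower-right);
// the per-batch kv_end from nonpad_kv_seqlen only bounds which columns are attendable.
inline bool CausallyAllowed(int q_row, int kv_col, int past_kv_length, int kv_end) {
  if (kv_col >= kv_end) return false;        // per-batch valid-length upper bound
  return kv_col <= q_row + past_kv_length;   // past_kv_length == 0 -> upper-left
}
```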


🟠 Major

M-1. TensorScatter is_causal=1 coverage was removed instead of fixing the kernel

  • File: onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py:462,471 (and ~20 affected cases)

The ONNX spec does not declare is_causal=1 with TensorScatter + S_q != S_kv invalid; it produces a well-defined upper-left mask (just "useless" for decode). Dropping the tests masks C-1. CPU reference still applies bottom-right alignment in the same file (line ~120), so even the reference is internally inconsistent.

Fix: restore the cases after fixing C-1, OR explicitly reject the combination in dispatch with NOT_IMPLEMENTED. Silent wrong-numbers is the worst option.

M-2. H != H_v concat path relies on duplicate writes to the same output buffer

  • File: onnxruntime/core/providers/cuda/llm/attention.cc:1083–1102
LaunchConcatNewToPastKV(... pk, pk, nk, nk, out_k, out_k, ...);  // K
LaunchConcatNewToPastKV(... pv, pv, nv, nv, out_v, out_v, ...);  // V

Each call's K-block and V-block both write to the same destination. Correct only as a "benign data race" (same value, same address). Two issues:

  1. cuda-memcheck/racecheck and static analyzers will flag this.
  2. The contract is invisible from the call site — any future refactor of LaunchConcatNewToPastKV (atomicAdd, accumulation, paged-cache layout, RoPE variant) silently corrupts data here, with no ORT_ENFORCE to catch it.

Fix: add a K-only / V-only mode to LaunchConcatNewToPastKV, or use a small dedicated copy kernel for the H != H_v case.
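One possible shape for that fix, sketched as a hypothetical single-tensor concat helper so K and V can be handled by separate, non-overlapping calls; the existing LaunchConcatNewToPastKV signature is not reproduced here:

```cpp
// Copies past (length P) then new (length S) along the sequence axis for one
// BNSH-layout cache tensor. Host-side sketch of the kernel's job.
#include <cstddef>
#include <cstring>

void ConcatPastAndNew(const float* past, std::size_t past_len,
                      const float* cur, std::size_t cur_len,
                      std::size_t batch_x_heads, std::size_t head_size,
                      float* out /* [batch_x_heads, past_len + cur_len, head_size] */) {
  const std::size_t total_len = past_len + cur_len;
  for (std::size_t bn = 0; bn < batch_x_heads; ++bn) {
    float* dst = out + bn * total_len * head_size;
    if (past != nullptr) {
      std::memcpy(dst, past + bn * past_len * head_size,
                  past_len * head_size * sizeof(float));
    }
    std::memcpy(dst + past_len * head_size, cur + bn * cur_len * head_size,
                cur_len * head_size * sizeof(float));
  }
}
```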

M-3. ScaledCopyQkKernel leaks K-cache padding into output_qk

  • File: onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu (ScaledCopyQkKernel + caller)

For TensorScatter / external cache, positions [seqlens_k[b], total_kv) may contain stale or uninitialized data. ScaledCopyQkKernel copies all total_kv positions, propagating that into the user-visible output_qk. CPU/MEA reference behavior likely zeros or -infs those positions → cross-EP numerical divergence.

Fix: mask output_qk[b,*,*,i] to 0 (or -inf) for i ≥ seqlens_k[b] after copy, OR explicitly document that contents past seqlens_k are unspecified.
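A sketch of the first option (zeroing the padded columns after the copy). The layout assumption ([B, N, S_q, total_kv]) and the helper name are illustrative; the real fix would be a CUDA kernel or folded into ScaledCopyQkKernel:

```cpp
// Zero the kv columns beyond each batch's valid length so stale cache contents
// never leak into the user-visible output_qk.
#include <cstddef>
#include <cstdint>
#include <vector>

void ZeroPaddedQkColumns(std::vector<float>& output_qk,
                         const std::vector<std::int32_t>& seqlens_k,
                         int num_heads, int s_q, int total_kv) {
  const int batch = static_cast<int>(seqlens_k.size());
  for (int b = 0; b < batch; ++b)
    for (int n = 0; n < num_heads; ++n)
      for (int i = 0; i < s_q; ++i) {
        float* row = output_qk.data() +
                     ((static_cast<std::size_t>(b) * num_heads + n) * s_q + i) * total_kv;
        for (int j = seqlens_k[b]; j < total_kv; ++j) row[j] = 0.0f;
      }
}
```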

M-4. Possible OOB read in attn_mask handling for [B,H,S_q,kv_seq] mask + past_key

  • File: onnxruntime/core/providers/cuda/llm/attention.cc (mask → bias path replacing the deleted LaunchAddBiasInPlace of legacy RunUnfusedAttention)

The legacy path explicitly composed nonpad-bias + attn_mask into a [B, q, total_seq] buffer. The new path passes user attn_mask as attn_bias to softmax, which indexes via (...) * total_kv + i with i ∈ [0, total_kv). For a 4D mask of shape [B,H,S_q,kv_seq] (kv_seq, not total_seq) with past_key present, the bias buffer is sized B*H*S_q*kv_seq but the kernel reads up to total_kv columns → OOB for i ≥ kv_seq. Please confirm validation upstream rejects this shape; if not, this is a real bug.
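A small illustration of the suspected out-of-bounds condition, assuming the row-major indexing described above; `BiasIndex` is a hypothetical stand-in for the kernel's address computation:

```cpp
// The kernel indexes the bias as if its last dimension were total_kv, but a
// [B, H, S_q, kv_seq] user mask with past_key present only allocates kv_seq columns.
#include <cstddef>

std::size_t BiasIndex(std::size_t b, std::size_t h, std::size_t i, std::size_t j,
                      std::size_t num_heads, std::size_t s_q, std::size_t last_dim) {
  return ((b * num_heads + h) * s_q + i) * last_dim + j;
}
// With kv_seq = 4, past = 6, total_kv = 10: the kernel reads
// BiasIndex(..., j, ..., /*last_dim=*/10) for j in [0, 10), but the buffer only
// holds last_dim == 4 columns per row, so every j >= 4 reads past the end
// unless upstream validation rejects the shape.
```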

M-5. Workspace blow-up: unified unfused always allocates two FP32-sized score matrices

  • File: onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu:391
const size_t qk_bytes = AlignTo(SafeInt<size_t>(elems) * sizeof(float), kAlign);
const size_t softmax_bytes = AlignTo(SafeInt<size_t>(elems) * sizeof(float), kAlign);

For B=1, H=32, S_q=S_kv=4096, that's ~4 GiB scratch before output/present tensors. The deleted legacy MHA path did not universally force this for plain MHA. After this PR every fp16/bf16 MHA fallback (no Flash, no MEA — old GPUs, head sizes MEA doesn't support, output_qk, builds without MEA) pays this cost.

Fix: keep a memory-threshold fallback or fail with a clear diagnostic before allocating multi-GB scratch. At minimum, document the regression envelope. No benchmark evidence accompanied the consolidation.
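A quick check of the quoted figure under those dimensions (back-of-envelope only, ignoring alignment padding):

```cpp
// Scratch estimate for B=1, H=32, S_q=S_kv=4096 with two fp32 score buffers.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t B = 1, H = 32, S_q = 4096, S_kv = 4096;
  const std::size_t elems = B * H * S_q * S_kv;               // 536,870,912 scores
  const std::size_t per_buffer = elems * sizeof(float);       // 2 GiB in fp32
  const std::size_t total = 2 * per_buffer;                   // QK + softmax buffers ~ 4 GiB
  std::printf("%.2f GiB\n", total / (1024.0 * 1024.0 * 1024.0));
  return 0;
}
```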

M-6. Eligibility predicate for past_key + MEA is duplicated in two places

  • File: onnxruntime/core/providers/cuda/llm/attention.cc:1346 (eligibility) and :610–611 (ORT_RETURN_IF_NOT inside RunMemoryEfficientAttention)

Same condition (head_size == v_head_size for MEA decode) lives in both. If the eligibility predicate ever drifts, the inner check returns an error to the user instead of falling back to unfused. Extract a helper.


🟡 Readability

R-1. TestONNXAttentionPaddingMaskMEAGQA docstring still says SKIPPED: and references a future PR — but the class is now active and is exactly that PR

  • File: onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py:2743–2749

A reader will hunt for a @unittest.skip decorator that no longer exists.

R-2. The 6-line "ONNX spec: is_causal means upper-left…" comment block is duplicated verbatim in two MEA branches (attention.cc:1018-1026 and attention.cc:1071-1076). Extract into a single named local before the branching; comments will drift otherwise.

R-3. causal_cross_no_past is defined inside #if USE_FLASH_ATTENTION (attention.cc:1544–1547). The predicate is dispatch-level semantics; a reader auditing MEA/Unfused eligibility won't see it. Move outside the preprocessor guard with a comment that only Flash needs the guard.

R-4. past_kv_length is computed by two different formulas in two callers: group_query_attention_impl.cu:382 (total_seq − seq) vs attention.cc:1488 (past_sequence_length). Same value, different reading effort. Pick one form.

R-5. New causal alignment tests are CUDA-only (disable_cpu=true) — the canonical spec example for causal cross-attention should pin cross-EP numerical agreement. Run on CPU too so the spec contract (not the implementation) is the test.


✅ Things that look right

  • Flash-blocking guard for causal-cross-no-past correctly uses parameters.total_sequence_length and is appropriately tight (does not over-block square or non-causal cases).
  • causal_from_top_left is only consulted when params.causal == true in fmha_launch_template.h, so the non-causal MEA path is unperturbed and contrib MHA/GQA callers compile unchanged with safe defaults.
  • ScaledCopyQkKernel is launched after the FP32 QK GEMM and before softcap/mask/softmax — correct kQK semantics, and dtype templating covers fp16/bf16/fp32.
  • kCutlassSafeMaskFilterValue constant with the −FLT_MAX × kLog2e overflow derivation is exemplary numeric-invariant documentation (a numeric check follows this list).
  • Attention4DSoftCapOutputQkRawLogits and the poison-value tests are sharp regression guards that pin operation ordering (QK → softcap → mask → softmax, per "Softcap in Attention op", onnx/onnx#7865).
  • Deleting the legacy RunUnfusedAttention (~228 lines) collapses 4 dispatch paths to 3 with a clearly stated cascade — material reduction in maintenance surface.
  • The new SKILL.md is genuinely additive (alignment table, ScopedEnvironmentVariables recipe, per-kernel guidance) rather than a comment transcript.
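The numeric check referenced in the kCutlassSafeMaskFilterValue bullet, written out. It assumes the kernel rescales logits by log2(e) so it can use exp2, which is what the constant's derivation suggests; the in-kernel usage is not reproduced:

```cpp
// Multiplying a mask value near -FLT_MAX by log2(e) leaves the float range
// (-> -inf), while the -1e+30f cap stays finite.
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  const float kLog2e = 1.4426950408889634f;
  const float uncapped = -FLT_MAX * kLog2e;  // magnitude > FLT_MAX -> -inf
  const float capped = -1e30f * kLog2e;      // ~ -1.44e30, still representable
  std::printf("uncapped finite: %d, capped finite: %d\n",
              static_cast<int>(std::isfinite(uncapped)),
              static_cast<int>(std::isfinite(capped)));
  return 0;
}
```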

Recommended order of operations

  1. Fix C-1 in unfused_attention.cu:120 (use past_kv_length unconditionally for the causal cutoff).
  2. Restore the TensorScatter is_causal=1 Python coverage with corrected upper-left expectations (validates the C-1 fix).
  3. Address M-3 (output_qk leakage past seqlens_k) — also a correctness/portability issue.
  4. Decide on M-4 — confirm validation rejects the shape, or fix the indexing.
  5. Replace M-2 with an explicit K-only/V-only concat mode.
  6. Address M-5 with a fallback heuristic or a clear failure path.
  7. Readability passes (R-1 to R-5) can land alongside the fix or as a follow-up.

The single must-fix before merge is C-1 — it produces silent wrong numbers on a path that the PR explicitly claims to make correct.

— Synthesized from reviews by gpt-5.3-codex, gpt-5.5, claude-sonnet-4.6, claude-opus-4.7-high.


Copilot AI left a comment


Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.


tianleiwu previously approved these changes May 1, 2026
titaiwangms and others added 6 commits May 4, 2026 16:56
…r dispatch

Refactor ONNX Attention CUDA dispatch to a clean 3-tier cascade:
Flash → Memory-Efficient (CUTLASS) → Unified Unfused.

- Remove RunUnfusedAttention (Legacy MHA Unfused path)
- Rename gqa_unfused_attention.cu → unfused_attention.cu as shared kernel
- Fix causal alignment: upper-left (no past) vs lower-right (with past) per ONNX spec
- Add causal_from_top_left flag to MEA (CUTLASS supports both alignments)
- Add Flash guard: block causal + cross-attention + no past (no upper-left support)
- Fix sliding-window per-batch past derivation when seqlens_k varies
- Cap ScaledCopyQkKernel grid with grid-stride loop
- Fix v_head_size reshape in test infrastructure for H!=H_v
- Fully qualify kCutlassSafeMaskFilterValue namespace
- Remove dead mask functions from attention_mask_impl
- Add 4 C++ causal alignment tests
- Fix Python TensorScatter tests (is_causal=0 for decode per spec)
- Remove test_attention_4d_gqa_with_past_and_present_fp16_cuda filter (now passes)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…aming

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…icate, readability fixes

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…LEMENTED

Per ONNX spec, is_causal without past_key means upper-left alignment which
is meaningless for decode (q[0] only sees kv[0]). Reject explicitly rather
than producing silent wrong results through the unfused kernel path.

Add negative test verifying the guard fires correctly.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The 4 causal alignment tests (UpperLeft, UpperLeftSmallHead,
DecodeWithPastLowerRight, SquareNoPast) were CUDA-only. CPU EP
handles them correctly — enable CPU execution.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch from c679669 to c8147e4 on May 4, 2026 at 17:05
@titaiwangms titaiwangms requested a review from tianleiwu May 4, 2026 17:33
@titaiwangms titaiwangms merged commit 4ca6b22 into microsoft:main May 4, 2026
88 of 89 checks passed
@titaiwangms titaiwangms deleted the feature/mea-decode-support-v2 branch May 4, 2026 20:41


Development

Successfully merging this pull request may close these issues.

ONNX Attention CUDA: Coverage Gaps in Runner Fallback Paths

4 participants