
Support softcap and softmax_precision in Attention (CUDA) #27714

Merged

titaiwangms merged 6 commits into main from titaiwang/support_softcap_softmax_precision on Mar 19, 2026

Support softcap and softmax_precision in Attention(CUDA)#27714
titaiwangms merged 6 commits intomainfrom
titaiwang/support_softcap_softmax_precision

Conversation

titaiwangms (Contributor) commented Mar 17, 2026

Fix #27712

This pull request improves support and validation for the softcap and softmax_precision attributes in the CUDA Attention operator, updates kernel eligibility and fallback logic, and enhances test coverage for these features. The changes ensure that only valid values are accepted, propagate new parameters to eligible kernels, and clarify backend capabilities in code comments and tests.

CUDA Attention operator improvements:

  • Added validation to enforce that softcap is non-negative and that softmax_precision is one of the supported TensorProto types (0, 1, 10, or 16); see the sketch after this list.
  • Updated code comments and eligibility checks to clarify that softcap is now supported natively in the Flash and Memory Efficient Attention (MEA) kernels, and that softmax_precision is inherently satisfied (softmax is always computed in FP32 on CUDA).
  • Propagated the softcap parameter to the MEA kernel invocation to enable native support.
  • Modified fallback and rejection logic: unfused attention now explicitly rejects softcap with a clear error message, while softmax_precision is always considered satisfied.
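A rough sketch of what the constructor-time validation could look like is shown below; it assumes onnxruntime's ORT_ENFORCE macro, and the member names (softcap_, softmax_precision_) and message wording are illustrative assumptions, not the merged code:

```cpp
// Hypothetical load-time checks in the CUDA Attention constructor.
// TensorProto data types: 0 = not set, 1 = FLOAT, 10 = FLOAT16, 16 = BFLOAT16.
ORT_ENFORCE(softcap_ >= 0.0f,
            "softcap must be non-negative, got ", softcap_);
ORT_ENFORCE(softmax_precision_ == 0 || softmax_precision_ == 1 ||
                softmax_precision_ == 10 || softmax_precision_ == 16,
            "softmax_precision must be 0 (not set), 1 (FLOAT), 10 (FLOAT16), or 16 (BFLOAT16), got ",
            softmax_precision_);
```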

Testing improvements:

  • Added a new test to verify that softmax_precision=1 (FLOAT) produces results identical to the default, since all CUDA backends compute softmax in FP32 (a small illustration follows this list).
  • Clarified in existing softcap-related tests that certain configurations are not supported by CUDA unfused attention and require Flash or MEA; updated test comments accordingly.
  • Expanded Python test cases for GQA (grouped-query attention) to include nonzero softcap values, increasing coverage of this feature.
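As a standalone illustration (not ORT code) of the invariant that test relies on: when the softmax reduction itself accumulates in FP32, requesting softmax_precision=1 (FLOAT) cannot change the result, whatever the tensor dtype is.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax with an FP32 accumulator, mirroring what the
// Flash, MEA, and unfused CUDA backends already do for the softmax step.
std::vector<float> SoftmaxFp32(const std::vector<float>& logits) {
  const float max_val = *std::max_element(logits.begin(), logits.end());
  std::vector<float> out(logits.size());
  float sum = 0.0f;  // FP32 accumulation: softmax_precision=1 is already in effect
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_val);
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}

int main() {
  for (float p : SoftmaxFp32({1.0f, 2.0f, 3.0f})) std::printf("%.6f\n", p);
  return 0;
}
```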

titaiwangms and others added 3 commits March 17, 2026 20:59
Enable softcap support for Flash Attention and Memory-Efficient Attention
(MEA) paths in the CUDA Attention operator. Enable softmax_precision for
all CUDA paths.

Changes:
- Remove early-reject blocks for softcap and softmax_precision that were
  blocking all CUDA kernel paths (resolving TODO from PR #27542)
- Forward softcap parameter to MEA kernel via p.softcap in both
  MemoryEfficientAttentionParams construction sites (nonpad_kv_seqlen
  path and standard path)
- Add softcap rejection at the unfused fallback section with a clear
  error message explaining Flash or MEA is required
- Add comment explaining that all CUDA backends (Flash, MEA, Unfused)
  already accumulate softmax in FP32, so softmax_precision is inherently
  satisfied

Flash Attention already passes parameters.softcap to mha_fwd and
mha_fwd_kvcache (3 call sites). The Is_softcap=true kernel variants
are already compiled (FLASHATTENTION_DISABLE_SOFTCAP is not defined).
MEA supports softcap via runtime branching in kernel_forward.h.

No new kernel compilation or template instantiations required.
Zero binary size impact.
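For context, softcap squashes the attention scores with a tanh before softmax. The standalone sketch below shows that transform in isolation; it is not the MEA kernel code (which applies softcap via runtime branching in kernel_forward.h, as noted above), just an illustration of the math the kernels implement.

```cpp
#include <cmath>
#include <vector>

// Logit soft-capping: scores are squashed into (-softcap, +softcap) via tanh.
// A softcap of 0.0f means the feature is disabled, matching the > 0.0f checks.
void ApplySoftcap(std::vector<float>& scores, float softcap) {
  if (softcap <= 0.0f) return;
  const float inv_cap = 1.0f / softcap;
  for (float& s : scores) {
    s = softcap * std::tanh(s * inv_cap);
  }
}
```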

Part of: #27712

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Remove 5 pure-softcap CUDA test filter entries from
onnx_backend_test_series_filters.jsonc now that softcap is supported
in Flash and MEA paths. Keep 3 qk_matmul+softcap combo filters
(need unfused-path softcap, deferred) and 2 GQA fp32 softcap filters
(blocked by GQA unfused support, issue #27516).

Add softcap=50.0 to Python GQA test case generators
(gqa_prompt_test_cases and gqa_past_test_cases) to exercise CUDA
softcap through the ONNX Attention GQA path.

Add explanatory comments to C++ softcap tests explaining why
disable_cuda remains true: head_size != v_head_size blocks Flash,
past_key blocks MEA, and unfused doesn't support softcap.

Part of: #27712

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Address code review findings:

1. Update stale comments that said softcap and softmax_precision were
   early-rejected before the cascade. Flash passes softcap natively,
   MEA forwards via p.softcap, and all backends accumulate softmax in
   FP32. Updated comments at RunFlashAttention (~L174),
   RunMemoryEfficientAttention (~L548), and RunUnfusedAttention (~L826).

2. Change unfused fallback softcap guard from `!= 0.0f` to `> 0.0f`
   for consistency with the Flash/MEA kernels and the CPU implementation;
   negative softcap is mathematically invalid (a sketch of the guard
   follows this list).

3. Add ORT_ENFORCE(softcap_ >= 0.0f) validation in the Attention
   constructor to catch invalid negative values at model load time.
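A sketch of what the unfused-fallback guard from item 2 might look like; ORT_MAKE_STATUS is onnxruntime's usual status macro, but the surrounding context and the exact message text here are assumptions (the item 3 ORT_ENFORCE check is sketched earlier, after the PR description):

```cpp
// Hypothetical guard in the unfused fallback path. Using > 0.0f rather than
// != 0.0f: 0.0f means disabled, and negative values are rejected at load time.
if (parameters.softcap > 0.0f) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "softcap is not supported by the unfused CUDA attention path; "
                         "use a configuration eligible for Flash or Memory Efficient Attention.");
}
```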

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms requested a review from Copilot March 17, 2026 21:19
titaiwangms marked this pull request as ready for review March 17, 2026 21:20
Copilot AI left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.



Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Remove incorrect claim about 'cuBLAS FP32 compute' in the
softmax_precision comment. cublasGemmStridedBatchedHelper for __half
can use FP16 compute depending on HalfGemmOptions, so claiming FP32
GEMM is inaccurate. The comment now correctly focuses on what
softmax_precision actually controls: the softmax kernel itself uses
FP32 arithmetic in all backends.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.



Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Address PR #27714 review comments:

1. Add ORT_ENFORCE validation for softmax_precision in the Attention
   constructor. Valid values are TensorProto data types: 0 (not set),
   1 (FLOAT), 10 (FLOAT16), 11 (DOUBLE), 16 (BFLOAT16).

2. Fix misleading error message in unfused softcap fallback rejection.
   Remove inaccurate 'Ensure fp16/bf16 on Ampere+ GPU' — MEA supports
   FP32 on SM50+ and Flash supports FP16 on SM75+ (Turing). New message
   mentions dtype, head_size constraints, and past_key compatibility.

3. Add Attention4DSoftmaxPrecisionFloat test case that sets
   softmax_precision=1 (FLOAT) and verifies output matches the default
   case. Enabled for both CPU and CUDA since all backends already
   compute softmax in FP32.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Copilot AI left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.



Comment thread onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc Outdated
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Fix 3 issues found in PR #27714 review:

1. Re-add diff_heads_sizes_softcap CUDA test filters. These tests use
   head_size != v_head_size which blocks Flash (requires equal sizes),
   falling to unfused which doesn't support softcap. Added comment
   explaining the constraint.

2. Remove duplicate unanchored skip pattern for
   test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda.
   Only the anchored (^) version is needed.

3. Remove TensorProto::DOUBLE (11) from valid softmax_precision values.
   CUDA computes softmax in FP32 and cannot satisfy FP64 precision.
   Valid values are now: 0 (not set), 1 (FLOAT), 10 (FLOAT16),
   16 (BFLOAT16).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms added the ep:CUDA (issues related to the CUDA execution provider) label Mar 19, 2026

tianleiwu (Contributor) left a comment

I did not find any blocking correctness issues. The change consistently routes softcap through the Flash and MEA paths, keeps unsupported unfused configurations rejected, and narrows the backend-test filters to the combinations that are still genuinely unsupported.

The main residual risk is coverage depth rather than implementation logic: the new tests exercise the newly enabled paths well enough to justify the filter changes, but there is still limited direct path-specific coverage for MEA-only softcap scenarios and invalid softmax_precision values.

titaiwangms merged commit 93d31cf into main Mar 19, 2026
91 checks passed
titaiwangms deleted the titaiwang/support_softcap_softmax_precision branch March 19, 2026 23:43