Support softcap and softmax_precision in Attention(CUDA)#27714
titaiwangms merged 6 commits into main
Conversation
Enable softcap support for the Flash Attention and Memory-Efficient Attention (MEA) paths in the CUDA Attention operator. Enable softmax_precision for all CUDA paths.

Changes:
- Remove the early-reject blocks for softcap and softmax_precision that were blocking all CUDA kernel paths (resolving the TODO from PR #27542)
- Forward the softcap parameter to the MEA kernel via p.softcap at both MemoryEfficientAttentionParams construction sites (the nonpad_kv_seqlen path and the standard path)
- Add softcap rejection at the unfused fallback with a clear error message explaining that Flash or MEA is required
- Add a comment explaining that all CUDA backends (Flash, MEA, Unfused) already accumulate softmax in FP32, so softmax_precision is inherently satisfied

Flash Attention already passes parameters.softcap to mha_fwd and mha_fwd_kvcache (3 call sites). The Is_softcap=true kernel variants are already compiled (FLASHATTENTION_DISABLE_SOFTCAP is not defined). MEA supports softcap via runtime branching in kernel_forward.h. No new kernel compilation or template instantiations are required; zero binary size impact.

Part of: #27712

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
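For orientation, the following is a minimal C++ sketch of the dispatch behavior this commit describes. The struct and function names (MeaParams, ApplySoftcap, LaunchMea, RunUnfused) are illustrative placeholders rather than the actual onnxruntime symbols; only the p.softcap field and the "reject softcap on the unfused path" rule come from the commit text, and the tanh formula is the usual logit soft-capping convention used by the fused kernels.

```cpp
// Minimal sketch, assuming the usual softcap convention of the fused kernels:
// scores become softcap * tanh(score / softcap) before softmax, and 0 disables it.
// All names below are placeholders, not onnxruntime code.
#include <cmath>
#include <stdexcept>

struct MeaParams {
  float softcap = 0.0f;  // 0 means "softcap disabled"
  // ... other launch parameters elided ...
};

// The per-score transform applied when softcap > 0.
inline float ApplySoftcap(float score, float softcap) {
  return softcap > 0.0f ? softcap * std::tanh(score / softcap) : score;
}

// Mirrors "forward softcap via p.softcap at both construction sites".
void LaunchMea(float softcap /*, ... */) {
  MeaParams p;
  p.softcap = softcap;
  // ... launch memory-efficient attention with p ...
}

// Mirrors the unfused-fallback rejection: no fused path was eligible and the
// unfused math has no softcap implementation.
void RunUnfused(float softcap /*, ... */) {
  if (softcap > 0.0f) {
    throw std::runtime_error("softcap requires the Flash or MEA path");
  }
  // ... unfused attention ...
}
```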
Remove 5 pure-softcap CUDA test filter entries from onnx_backend_test_series_filters.jsonc now that softcap is supported in the Flash and MEA paths. Keep the 3 qk_matmul+softcap combo filters (they need unfused-path softcap, which is deferred) and the 2 GQA fp32 softcap filters (blocked by GQA unfused support, issue #27516).

Add softcap=50.0 to the Python GQA test case generators (gqa_prompt_test_cases and gqa_past_test_cases) to exercise CUDA softcap through the ONNX Attention GQA path.

Add explanatory comments to the C++ softcap tests explaining why disable_cuda remains true: head_size != v_head_size blocks Flash, past_key blocks MEA, and unfused doesn't support softcap.

Part of: #27712

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
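As a reading aid for those disable_cuda comments, here is a hypothetical condensation of the eligibility reasoning into a single predicate. The function name and arguments are illustrative, not onnxruntime code, and the past_key constraint is hedged to the specific configurations the commit describes.

```cpp
// Illustrative condensation of the test-comment reasoning: Flash needs
// head_size == v_head_size, MEA is blocked by past_key in these configurations,
// and the unfused fallback rejects softcap, so such tests stay CUDA-disabled.
bool SoftcapCanRunOnCuda(int head_size, int v_head_size, bool has_past_key) {
  const bool flash_eligible = (head_size == v_head_size);
  const bool mea_eligible = !has_past_key;
  // The unfused path never supports softcap, so only the fused paths count.
  return flash_eligible || mea_eligible;
}
```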
Address code review findings:

1. Update stale comments that said softcap and softmax_precision were early-rejected before the cascade. Flash passes softcap natively, MEA forwards it via p.softcap, and all backends accumulate softmax in FP32. Updated the comments at RunFlashAttention (~L174), RunMemoryEfficientAttention (~L548), and RunUnfusedAttention (~L826).
2. Change the unfused fallback softcap guard from != 0.0f to > 0.0f for consistency with the Flash/MEA kernels and the CPU implementation. A negative softcap is mathematically invalid.
3. Add ORT_ENFORCE(softcap_ >= 0.0f) validation in the Attention constructor to catch invalid negative values at model load time.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
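A minimal sketch of the load-time check added in item 3, with a plain exception standing in for ORT_ENFORCE; the struct and constructor shape are illustrative, not the actual Attention kernel class.

```cpp
// Sketch of the constructor-time validation (illustrative shape only).
// The check itself matches the commit: negative softcap is rejected when the
// kernel is created, so invalid models fail at load time rather than at run time.
#include <stdexcept>

struct AttentionSketch {
  float softcap_;
  explicit AttentionSketch(float softcap_attr) : softcap_(softcap_attr) {
    // In onnxruntime this is an ORT_ENFORCE; a standard exception stands in here.
    if (!(softcap_ >= 0.0f)) {
      throw std::invalid_argument("softcap must be non-negative");
    }
  }
};
```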
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
Remove the incorrect claim about 'cuBLAS FP32 compute' in the softmax_precision comment. cublasGemmStridedBatchedHelper for __half can use FP16 compute depending on HalfGemmOptions, so claiming FP32 GEMM is inaccurate. The comment now focuses on what softmax_precision actually controls: the softmax kernel itself uses FP32 arithmetic in all backends.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
Address PR #27714 review comments:

1. Add ORT_ENFORCE validation for softmax_precision in the Attention constructor. Valid values are TensorProto data types: 0 (not set), 1 (FLOAT), 10 (FLOAT16), 11 (DOUBLE), 16 (BFLOAT16).
2. Fix the misleading error message in the unfused softcap fallback rejection. Remove the inaccurate 'Ensure fp16/bf16 on Ampere+ GPU' hint, since MEA supports FP32 on SM50+ and Flash supports FP16 on SM75+ (Turing). The new message mentions dtype, head_size constraints, and past_key compatibility.
3. Add an Attention4DSoftmaxPrecisionFloat test case that sets softmax_precision=1 (FLOAT) and verifies the output matches the default case. It is enabled for both CPU and CUDA since all backends already compute softmax in FP32.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
Fix 3 issues found in PR #27714 review:

1. Re-add the diff_heads_sizes_softcap CUDA test filters. These tests use head_size != v_head_size, which blocks Flash (it requires equal sizes), so they fall to the unfused path, which doesn't support softcap. Added a comment explaining the constraint.
2. Remove the duplicate unanchored skip pattern for test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda. Only the anchored (^) version is needed.
3. Remove TensorProto::DOUBLE (11) from the valid softmax_precision values. CUDA computes softmax in FP32 and cannot satisfy FP64 precision. Valid values are now: 0 (not set), 1 (FLOAT), 10 (FLOAT16), 16 (BFLOAT16).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (f86cdbc3) [claude-opus-4.6]
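To make the final accepted set concrete, here is a hedged sketch of a whitelist check over the softmax_precision values listed in item 3. The numeric codes are the standard ONNX TensorProto data-type values; the function name and surrounding code are illustrative, not the actual ORT_ENFORCE expression.

```cpp
// Illustrative check over the final accepted softmax_precision values.
// Codes are ONNX TensorProto data types: 0 = UNDEFINED (attribute not set),
// 1 = FLOAT, 10 = FLOAT16, 16 = BFLOAT16.
#include <cstdint>

bool IsValidSoftmaxPrecision(int64_t softmax_precision) {
  switch (softmax_precision) {
    case 0:   // not set: backend default (FP32 accumulation on CUDA)
    case 1:   // FLOAT: satisfied, the CUDA softmax already accumulates in FP32
    case 10:  // FLOAT16: satisfied a fortiori by FP32 accumulation
    case 16:  // BFLOAT16: likewise satisfied
      return true;
    default:  // e.g. 11 (DOUBLE): CUDA cannot guarantee FP64 softmax
      return false;
  }
}
```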
tianleiwu left a comment
I did not find any blocking correctness issues. The change consistently routes softcap through the Flash and MEA paths, keeps unsupported unfused configurations rejected, and narrows the backend-test filters to the combinations that are still genuinely unsupported.
The main residual risk is coverage depth rather than implementation logic: the new tests exercise the newly enabled paths well enough to justify the filter changes, but there is still limited direct path-specific coverage for MEA-only softcap scenarios and invalid softmax_precision values.
Fix #27712
This pull request improves support and validation for the softcap and softmax_precision attributes in the CUDA Attention operator, updates kernel eligibility and fallback logic, and enhances test coverage for these features. The changes ensure that only valid values are accepted, propagate new parameters to eligible kernels, and clarify backend capabilities in code comments and tests.

CUDA Attention operator improvements:
- Validate in the Attention constructor that softcap is non-negative and that softmax_precision is one of the supported TensorProto types (0, 1, 10, or 16).
- Update comments to note that softcap is now supported natively in the Flash and Memory Efficient Attention (MEA) kernels, and that softmax_precision is inherently satisfied (softmax is always computed in FP32 on CUDA). [1] [2] [3]
- Forward the softcap parameter to the MEA kernel invocation to enable native support. [1] [2]
- Reject softcap on the unfused fallback with a clear error message, while softmax_precision is always considered satisfied. [1] [2]

Testing improvements:
- Add a test verifying that softmax_precision=1 (FLOAT) produces identical results to the default, since all CUDA backends compute softmax in FP32.
- Extend the GQA test case generators with nonzero softcap values, increasing coverage of this feature. [1] [2]