diff --git a/.agents/skills/cuda-attention-kernel-patterns/SKILL.md b/.agents/skills/cuda-attention-kernel-patterns/SKILL.md
new file mode 100644
index 0000000000000..5325a1bf22bdc
--- /dev/null
+++ b/.agents/skills/cuda-attention-kernel-patterns/SKILL.md
@@ -0,0 +1,237 @@
+---
+name: cuda-attention-kernel-patterns
+description: Patterns and pitfalls for the ONNX domain Attention operator (opset 23/24) CUDA implementation. Use when modifying the dispatch cascade in core/providers/cuda/llm/attention.cc, writing mask/bias CUDA kernels, debugging attention test routing, or adding features to the ONNX Attention op. NOT for contrib domain MultiHeadAttention/GroupQueryAttention.
+---
+
+# ONNX Domain Attention (Opset 23/24) CUDA Patterns
+
+Reusable knowledge from ONNX Attention CUDA development in ORT.
+
+> **Scope**: This skill covers the **ONNX domain** `Attention` operator (opset 23/24)
+> implemented at `core/providers/cuda/llm/attention.cc`. This is **separate from** the
+> contrib domain `MultiHeadAttention` / `GroupQueryAttention` at `contrib_ops/cuda/bert/`.
+> They share some underlying kernels (CUTLASS FMHA, Flash Attention) and infrastructure
+> (`attention_softmax.h`) but have **different dispatch logic, parameter structs, and eligibility checks**.
+>
+> - **Shared infrastructure**: CUTLASS FMHA kernel, Flash kernel, unified unfused kernel
+>   (`unfused_attention.cu`), `attention_softmax.h`, `attention_impl.cu` (contrib only)
+> - **ONNX-specific**: Dispatch cascade in `attention.cc`, `ConvertAttnMaskToBias`,
+>   `mask_filter_value` cap, parameter bridge to contrib structs, `attention_mask_impl.cu`
+> - **Contrib-specific**: Own dispatch in contrib MHA/GQA ops, uses `contrib::AttentionParameters`
+>   directly, has XQA kernel, past-present buffer sharing
+
+## 1. Runner Dispatch Cascade
+
+CUDA attention dispatches in priority order: **Flash → MEA (Memory Efficient) → Unified Unfused Attention**.
+
+```
+// onnxruntime/core/providers/cuda/llm/attention.cc — ComputeInternal()
+Flash eligible?  → RunFlashAttention()
+    ↓ no
+MEA eligible?    → RunMemoryEfficientAttention()
+    ↓ no
+Unified Unfused  → RunUnfusedAttention()
+                   (handles both MHA and GQA via reshape-Q trick)
+```
+
+**Flash eligibility**: fp16/bf16 only, SM ≥ 8.0 (Ampere+), `head_size == v_head_size`, `head_size <= 256`, no `output_qk`, `attn_mask == nullptr`. Uses `mha_fwd` / `mha_fwd_kvcache`.
+
+**MEA eligibility**: SM 50+/53+/80+ by dtype, `head_size <= 1024` and divisible by 8, no `output_qk`. Decode requires `head_size == v_head_size` (for `LaunchConcatNewToPastKV`). Bias stride must satisfy `total_sequence_length % 4 == 0`. GQA with fp32 is excluded (`LaunchUngroup` only has fp16/bf16 instantiations). Supports `softcap + attn_mask` — CUTLASS applies softcap before bias in kernel tiles, matching ONNX spec ordering (onnx/onnx#7865).
+
+**Unified Unfused Attention**: Always available as the final fallback. Handles both MHA (`num_heads == kv_num_heads`, group=1) and GQA (`num_heads != kv_num_heads`, group>1) via a reshape-Q trick with stride-based cuBLAS batched GEMM (no K/V head replication). Uses FP32 QK scratch for precision. Supports all features (a condensed dispatch sketch follows this list):
+- softcap + attn_mask (spec-correct ordering)
+- output_qk (kQK mode: copies raw QK before softcap/mask mutations)
+- past_key + past_value with `head_size != v_head_size` (separate K/V concat)
+- causal masking, nonpad_kv_seqlen, all dtypes (fp16/bf16/fp32)
+
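+A minimal sketch of the fall-through pattern (illustrative only — the eligibility
+predicates are condensed into two hypothetical booleans; the real checks live in
+`ComputeInternal()` in `attention.cc`):
+
+```cpp
+// Hedged sketch, not the actual implementation.
+Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
+  // flash_eligible: fp16/bf16, SM >= 80, head_size == v_head_size <= 256,
+  //                 no output_qk, attn_mask == nullptr
+  if (flash_eligible) {
+    return RunFlashAttention(/*...*/);
+  }
+  // mea_eligible: SM/dtype gate, head_size <= 1024 and % 8 == 0,
+  //               no output_qk, total_sequence_length % 4 == 0 when bias present
+  if (mea_eligible) {
+    return RunMemoryEfficientAttention(/*...*/);
+  }
+  return RunUnfusedAttention(/*...*/);  // always-available final fallback
+}
+```
+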
+## 2. CUTLASS kLog2e Overflow
+
+CUTLASS `iterative_softmax` multiplies all attention scores by `kLog2e ≈ 1.4427` internally (to use `exp2f` instead of `expf`). For float/bf16:
+
+```
+mask_filter_value = std::numeric_limits<float>::lowest() ≈ -3.40e+38
+-3.40e+38 × 1.4427 ≈ -4.91e+38 → overflows fp32 → -inf
+```
+
+When all values become `-inf`, CUTLASS's special-case path produces `s_prime=0` → `1/s_prime=inf` → `0 × inf = NaN`.
+
+**Fix**: Cap `mask_filter_value` to `-1.0e+30f` in `ConvertAttnMaskToBias`. This value is safe: `1e30 × 1.4427 ≈ 1.4e30 << FLT_MAX`, and `exp(-1e30) ≈ 0` (effectively masked).
+
+**fp16 is NOT affected**: `lowest() = -65504`, and `-65504 × 1.4427 ≈ -94500` stays within fp32 range.
+
+This cap is ONLY applied in MEA paths. The unfused path uses `lowest()` directly (its softmax subtracts the row max first, avoiding overflow).
+
+**Subtlety**: When bias is present (`kSupportsBias=true`), CUTLASS pre-applies `p.scale` to QK (line 858) and uses `scaling=1.0f` in the softmax loop (line 981). So the full `kLog2e` multiplier hits the bias-dominated values — the overflow is head_size-independent. Without bias, `scaling = p.scale * kLog2e = kLog2e/sqrt(head_size)`, which is much smaller.
+
+## 3. Bias Alignment
+
+CUTLASS FMHA requires the attention bias row stride to satisfy a minimum alignment. The bias has shape `[B, H, S, T]` where `T = total_sequence_length` is the row stride.
+
+```cpp
+constexpr int min_bias_align = 4;  // elements, not bytes
+if (parameters.total_sequence_length % min_bias_align != 0) {
+  mea_eligible = false;  // fall through to unfused
+}
+```
+
+**Impact on tests**: If a test uses a `total_sequence_length` not divisible by 4 (e.g., past=5 + new=6 = 11), MEA is rejected and unfused handles it. To test MEA with bias, ensure `total_sequence_length % 4 == 0`.
+
+## 4. Softcap Ordering
+
+ONNX spec ordering (onnx/onnx#7865): `QK → scale → softcap → add mask/bias → softmax`
+
+- **MEA (CUTLASS)**: Fuses softcap before bias in the kernel tile loop (`kernel_forward.h`). Matches spec ordering.
+- **Flash**: Handles softcap natively in `mha_fwd`/`mha_fwd_kvcache` but rejects `attn_mask`, so ordering with mask is moot.
+- **Unfused**: Applies the spec-correct ordering `QK → scale → softcap → add bias → softmax`, all fused in `UnfusedSoftmaxKernel`.
+
+All three paths apply softcap BEFORE mask/bias. If softcap were applied after masking, `tanh(-inf/sc) = -sc` (finite), leaking probability to masked positions.
+
+## 5. Grid-Stride Loops for CUDA Kernels
+
+Always cap the grid size to stay within `gridDim.x` limits, and use grid-stride loops for large workloads:
+
+```cpp
+constexpr int64_t kMaxGridDimX = 65535;
+int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total));
+int64_t blocks = (total + threads - 1) / threads;
+unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
+
+MyKernel<<<grid_size, threads, 0, stream>>>(...);
+
+// Inside the kernel:
+for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+     idx < total;
+     idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
+  // work
+}
+```
+
+**Never** cast an `int64_t` block count directly to `unsigned int` without capping — it silently truncates.
+
+Always call `CUDA_CALL(cudaGetLastError())` after kernel launches in standalone helper functions. This is the established pattern in the file (see `ConcatPastToPresent`, `PastPresentBufferShare`). A host-side sketch combining both rules follows.
+
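+A host-side sketch combining the grid cap with the launch-error check (hedged —
+`LaunchScaleKernel` and `ScaleKernel` are hypothetical names, not helpers in the file):
+
+```cpp
+// Hypothetical standalone launch helper following the conventions above.
+template <typename T>
+Status LaunchScaleKernel(const T* in, T* out, int64_t total,
+                         int max_threads_per_block, cudaStream_t stream) {
+  constexpr int64_t kMaxGridDimX = 65535;
+  int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total));
+  int64_t blocks = (total + threads - 1) / threads;
+  unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
+  ScaleKernel<T><<<grid_size, threads, 0, stream>>>(in, out, total);  // grid-stride kernel
+  CUDA_RETURN_IF_ERROR(cudaGetLastError());  // surface launch failures immediately
+  return Status::OK();
+}
+```
+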
+## 6. Fully-Masked Batches
+
+All-false bool masks or `seqlens_k=0` produce NaN in CUTLASS MEA.
+
+**Additive-bias path** (bool mask converted to bias): Fixed by capping `mask_filter_value` to `-1e+30f` (see section 2). CUTLASS then naturally computes a uniform softmax → mean(V).
+
+**Nonpad path** (`seqlens_k=0`): CUTLASS skips all K/V positions → `s_prime=0` → NaN. Fixed by the `ZeroOutputForFullyMaskedBatches` kernel, which zeros the output for batches where `seqlens_k[b] == 0`. Note: this produces zeros, not mean(V) — a cross-EP consistency TODO exists.
+
+**CPU/Unfused behavior**: `mask_filter_value = lowest()` (finite, not `-inf`). All masked values are equal → `softmax(equal) = 1/N` → output = mean(V). This is the spec reference.
+
+## 7. Test Runner Targeting
+
+Use `ScopedEnvironmentVariables` to force specific CUDA runners:
+
+```cpp
+// Force MEA (disable Flash)
+ScopedEnvironmentVariables scoped_env({
+    {"ORT_DISABLE_FLASH_ATTENTION", "1"},
+});
+
+// Force Unfused (disable both Flash and MEA)
+ScopedEnvironmentVariables scoped_env({
+    {"ORT_DISABLE_FLASH_ATTENTION", "1"},
+    {"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", "1"},
+});
+```
+
+**Always verify which runner a test actually hits.** A test designed for MEA may silently fall to unfused if:
+- `total_sequence_length % 4 != 0` (bias alignment)
+- `head_size != v_head_size` (decode path)
+- fp32 dtype with GQA (`LaunchUngroup` is fp16/bf16 only)
+- fp32 dtype on SM < 80
+
+Enable verbose logging to confirm: `LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using ..."`.
+
+## 8. Cross-EP Consistency
+
+CPU is the spec reference implementation. CUDA outputs should match CPU for all valid inputs.
+
+- CPU uses `mask_filter_value = std::numeric_limits<float>::lowest()` (finite, not `-inf`)
+- CPU softmax subtracts the row max first → works correctly with extreme finite values
+- CPU handles fully-masked batches naturally (uniform softmax → mean(V))
+
+Run tests with `disable_cpu=false` to always validate against CPU. The C++ test framework (`RunTest4D`) supports `disable_cpu`, `disable_cuda`, `disable_dml` flags. A reference-softmax sketch illustrating the fully-masked case follows.
+
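+A minimal sketch of the subtract-max reference softmax, showing why a fully-masked
+row degrades to mean(V) instead of NaN (illustrative, not the actual CPU kernel):
+
+```cpp
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+// For a fully-masked row every score equals lowest(); after subtracting the row
+// max all entries are 0 → exp(0) = 1 → uniform weights 1/N → output = mean(V).
+std::vector<float> SoftmaxRow(std::vector<float> scores) {
+  float row_max = *std::max_element(scores.begin(), scores.end());
+  float sum = 0.0f;
+  for (float& s : scores) {
+    s = std::exp(s - row_max);  // subtract-max keeps exp() in range
+    sum += s;
+  }
+  for (float& s : scores) s /= sum;  // sum == N for a fully-masked row
+  return scores;
+}
+```
+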
+## 9. File Locations
+
+### ONNX Domain (this op's code)
+
+| File | Purpose |
+|------|---------|
+| `core/providers/cuda/llm/attention.cc` | ONNX Attention CUDA dispatch: Flash/MEA/Unfused cascade, `ConvertAttnMaskToBias`, parameter setup |
+| `core/providers/cuda/llm/attention_mask_impl.cu` | ONNX-specific mask/bias CUDA kernels: bool→bias, nonpad→seqlens_k, ZeroOutput, bias composition |
+| `core/providers/cuda/llm/attention_mask_impl.h` | Declarations for ONNX mask/bias kernels |
+| `core/providers/cpu/llm/attention.cc` | CPU reference implementation (ONNX domain) |
+| `core/providers/cpu/llm/attention_helper.h` | ONNX parameter validation and shape computation |
+| `test/providers/cpu/llm/attention_op_test.cc` | C++ attention tests (all EPs) |
+| `test/python/transformers/test_onnx_attention/test_mha.py` | Python parity tests |
+| `test/python/transformers/test_onnx_attention/common.py` | Python test utilities and reference `attention_ref()` |
+
+### Shared Infrastructure (used by both ONNX and contrib ops)
+
+| File | Purpose |
+|------|---------|
+| `contrib_ops/cuda/bert/unfused_attention.cu` | Unified unfused attention: QK GEMM (FP32), fused softmax kernel (scale+softcap+bias+causal), V GEMM. Handles MHA and GQA. |
+| `contrib_ops/cuda/bert/unfused_attention.h` | `UnfusedAttentionParams`, `LaunchUnfusedAttention`, workspace size |
+| `contrib_ops/cuda/bert/attention_impl.cu` | Legacy unfused `QkvToContext` (contrib MHA only). Also `ApplySoftcap`, `ConcatPastToPresent` |
+| `contrib_ops/cuda/bert/attention_softmax.h` | CUDA softmax kernels (`ComputeSoftmax`, `ComputeSoftmaxWithRawMask`) — used by legacy contrib path |
+| `contrib_ops/cuda/bert/cutlass_fmha/` | CUTLASS FMHA (Memory Efficient Attention) kernels |
+| `contrib_ops/cuda/bert/flash_attention/` | Flash Attention kernels |
+
+### Contrib Domain (separate ops, NOT covered by this skill)
+
+| File | Purpose |
+|------|---------|
+| `contrib_ops/cuda/bert/multihead_attention.cu` | Contrib `MultiHeadAttention` — own dispatch, uses `contrib::AttentionParameters` directly |
+| `contrib_ops/cuda/bert/group_query_attention.cu` | Contrib `GroupQueryAttention` — has XQA kernel, past-present buffer sharing |
+
+## 10. Parameter Bridge (ONNX → Contrib)
+
+The ONNX Attention op uses `attention_helper::AttentionParameters` (in `core/providers/cpu/llm/attention_parameters.h`). The unified unfused kernel (`LaunchUnfusedAttention`) uses its own `UnfusedAttentionParams` struct, populated directly from ONNX parameters in `RunUnfusedAttention`.
+
+The contrib `QkvToContext` function (used by contrib MHA, NOT by ONNX Attention) uses `contrib::AttentionParameters`. ONNX Attention does **not** bridge to `contrib::AttentionParameters` — it routes through the unified unfused kernel instead.
+
+## 11. Causal Alignment
+
+The ONNX spec defines two causal alignment modes based on where query positions sit in the full attention matrix:
+
+- **Upper-left**: `q_i` attends to `kv[0..i]`. Query positions start at 0 in the full matrix.
+- **Lower-right**: `q_i` attends to `kv[kv_len - q_len + i..kv_len - 1]`. Query positions are at the end.
+
+**ONNX spec rule**: `is_causal=1` always means upper-left in the full matrix. When `past_key` provides context, `past_sequence_length` shifts the query start position forward — the resulting `[S_q × total_kv]` sub-matrix effectively has lower-right alignment.
+
+### Per-kernel behavior
+
+| Kernel | Alignment | Mechanism |
+|--------|-----------|-----------|
+| **Flash** | Lower-right only | `is_causal` flag → `seqlen_k - seqlen_q` offset in kernel. No top-left option. |
+| **MEA (CUTLASS)** | Both | `causal_from_top_left` flag in `MemoryEfficientAttentionParams`. `true` → `CausalFromTopLeft` (offset=0). `false` → `CausalFromBottomRight` (offset = num_keys - num_queries). |
+| **Unfused** | Both | `past_kv_length` param. `0` → upper-left. `total_kv - S_q` → lower-right. |
+
+### Dispatch logic in attention.cc
+
+```cpp
+// Flash cannot do upper-left → guarded by causal_cross_no_past
+bool causal_cross_no_past = parameters.is_causal &&
+                            parameters.q_sequence_length != parameters.total_sequence_length &&
+                            parameters.past_sequence_length == 0;
+
+// Flash: skip when causal_cross_no_past (no top-left support)
+// MEA: NOT skipped — handles it via causal_from_top_left = (past_sequence_length == 0)
+// Unfused: always correct via past_kv_length = parameters.past_sequence_length
+```
+
+### When S_q == S_kv
+
+Upper-left and lower-right produce **identical** results when `S_q == S_kv` (the offset is 0 either way). The alignment distinction only matters for cross-attention shapes (`S_q != S_kv`).
+
+### TensorScatter decode (opset 24 external KV cache)
+
+TensorScatter manages the KV cache externally — `past_key` is nullptr but K/V already contain the full sequence. Per the ONNX spec, `is_causal` with `S_q != S_kv` and no `past_key` means upper-left (q[0] sees only kv[0]), which is **not meaningful for decode**.
+
+**Correct pattern**: TensorScatter decode must use `is_causal=0` and rely on `nonpad_kv_seqlen` to bound the active KV range. Models using `is_causal=1` with TensorScatter decode have a spec-invalid combination. A small sketch of the two alignment offsets follows.
+
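+A small sketch of the two alignment offsets (mirrors the predicate the unfused
+kernel documents — `k > past_kv_length + q` is masked — not copied from it):
+
+```cpp
+// Returns true if query row q may attend to key column k.
+//   past_kv_length = 0              → upper-left  (q_i sees kv[0..i])
+//   past_kv_length = total_kv - S_q → lower-right (q_i sees kv[0..past + i])
+inline bool CausalAllowed(int q, int k, int past_kv_length) {
+  return k <= past_kv_length + q;
+}
+// Example with S_q = 2, total_kv = 5:
+//   upper-left  (past = 0): q0 sees k <= 0, q1 sees k <= 1
+//   lower-right (past = 3): q0 sees k <= 3, q1 sees k <= 4
+```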
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index f316a0dfdf91c..5b7624d11c6fd 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -33,6 +33,7 @@ struct AttentionParameters {
   bool broadcast_attn_bias_dim_1 = false;
   float mask_filter_value = 0.0f;
   float scale = 0.0f;
+  float softcap = 0.0f;
   bool use_tf32 = false;
   bool is_output_bnsh = false;  // whether the output format is BNSH
   AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
index 98f92b79e6ec6..60f2d05446da1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -205,11 +205,11 @@ struct GroupQueryAttentionData {
   void* xqa_buffer = nullptr;
   size_t xqa_buffer_bytes = 0;

-  // Unfused fallback buffers (see LaunchGqaUnfusedAttention in gqa_unfused_attention.h):
+  // Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h):
   //   unfused_q_bnsh   : [B, N_q, S_q, H]   (Q transposed from BSNH to BNSH)
   //   unfused_y_bnsh   : [B, N_q, S_q, H_v] (output BNSH, transposed to BSNH before leaving op)
   //   unfused_workspace: FP32 QK scratch + T softmax scratch (sized by
-  //                      GetGqaUnfusedAttentionWorkspaceSize)
+  //                      GetUnfusedAttentionWorkspaceSize)
   T* unfused_q_bnsh = nullptr;
   T* unfused_y_bnsh = nullptr;
   void* unfused_workspace = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index 29bb4fba6a09a..aedb370d38367 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -176,7 +176,14 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
   p.num_keys = params.kv_sequence_length;

   if (params.causal) {
-    p.custom_mask_type = Attention::CausalFromBottomRight;
+    // ONNX spec: is_causal means upper-left alignment (q_i attends to kv[0..i]).
+    // When past_sequence_length > 0 (decode with KV cache), positions shift → lower-right.
+    // causal_from_top_left=true: past_seq==0, use CausalFromTopLeft (offset=0).
+    // causal_from_top_left=false: past_seq>0 or S_q==S_kv, use CausalFromBottomRight
+    // (offset = num_keys - num_queries, which is 0 when square).
+    p.custom_mask_type = params.causal_from_top_left
+                             ? Attention::CausalFromTopLeft
+                             : Attention::CausalFromBottomRight;
   }

   // We use max_sequence_length to calculate KV stride
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index ace598489a226..a961be051a16a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -13,6 +13,13 @@ namespace cuda {

 constexpr int kEfficientAttentionMaxHeadSize = 1024;

+// CUTLASS online softmax multiplies attention scores by kLog2e (≈1.4427).
+// For float/bf16, |lowest() × kLog2e| > FLT_MAX, overflowing to -inf and
+// causing s_prime=0 → NaN for fully-masked batches. Cap to prevent this.
+// -1e+30 is safe: 1e30 × 1.4427 ≈ 1.4e30 << FLT_MAX ≈ 3.4e38, and
+// exp(-1e30) ≈ 0 (effectively masked). For fp16 lowest()=-65504 > -1e30, no-op.
+constexpr float kCutlassSafeMaskFilterValue = -1.0e+30f;
+
 struct MemoryEfficientAttentionParams {
   int32_t sm = 50;
   bool is_half = false;
@@ -27,6 +34,12 @@ struct MemoryEfficientAttentionParams {
   int32_t v_head_size = 0;
   int32_t local_window_size = -1;
   bool causal = false;
+  // When true, causal masking uses upper-left alignment (q_i attends to kv[0..i]).
+  // When false (default), uses lower-right alignment (q_i attends to kv[kv_len-q_len+i..kv_len-1]).
+  // ONNX Attention spec requires upper-left for cross-attention without past (S_q != S_kv, past=0).
+  // Lower-right is correct for decode with KV cache (past > 0).
+  // For square matrices (S_q == S_kv), both alignments produce identical results.
+  bool causal_from_top_left = false;
   bool use_smooth_softmax = false;
   bool broadcast_attn_bias_dim_0 = false;
   bool broadcast_attn_bias_dim_1 = false;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 5f21f3cd34e8f..dfecc2b810a04 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -14,7 +14,7 @@
 #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
 #include "contrib_ops/cuda/bert/xqa/xqa_loader.h"
-#include "contrib_ops/cuda/bert/gqa_unfused_attention.h"
+#include "contrib_ops/cuda/bert/unfused_attention.h"
 #include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
 #include "contrib_ops/cpu/utils/debug_macros.h"

@@ -513,7 +513,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   // GQA-capable unfused fallback (issue #28195).
   // Activates when Flash / MEA / XQA are all ineligible and KV is not quantized.
   // Supports any head_size (FP32 QK accumulation), GQA, sliding window, softcap.
-  // See LaunchGqaUnfusedAttention in contrib_ops/cuda/bert/gqa_unfused_attention.h.
+  // See LaunchUnfusedAttention in contrib_ops/cuda/bert/unfused_attention.h.
   // ---------------------------------------------------------------------
   IAllocatorUniquePtr<void> unfused_scratch;
   if (!data.use_xqa && !data.use_flash_attention && !data.use_memory_efficient_attention &&
@@ -538,7 +538,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
     const SafeInt<size_t> q_bnsh_bytes = align(SafeInt<size_t>(B) * N_q * S_q * H * sizeof(T));
     const SafeInt<size_t> y_bnsh_bytes = align(SafeInt<size_t>(B) * N_q * S_q * H_v * sizeof(T));
     const SafeInt<size_t> ws_bytes = SafeInt<size_t>(
-        onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize(
+        onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize(
            static_cast<int>(B), static_cast<int>(N_q), static_cast<int>(S_q), static_cast<int>(S_kv)));
     const SafeInt<size_t> workspace_offset = q_bnsh_bytes + y_bnsh_bytes;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index ebb6a0b0da215..70c58e6b8f764 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -38,7 +38,7 @@ limitations under the License.
 #include "contrib_ops/cuda/bert/bert_padding.h"
 #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
-#include "contrib_ops/cuda/bert/gqa_unfused_attention.h"
+#include "contrib_ops/cuda/bert/unfused_attention.h"
 #include "contrib_ops/cuda/bert/group_query_attention_impl.h"
 #include "contrib_ops/cpu/bert/attention_common.h"
 #include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh"
@@ -1095,7 +1095,7 @@ Status UnfusedGqaAttention(
   }

   // Step 3: run unfused attention with FP32 QK accumulation.
-  GqaUnfusedAttentionParams p;
+  UnfusedAttentionParams p;
   p.batch_size = batch_size;
   p.num_heads = num_heads;
   p.kv_num_heads = kv_num_heads;
@@ -1113,18 +1113,20 @@ Status UnfusedGqaAttention(
   p.broadcast_attn_bias_dim_1 = false;
   p.is_causal = parameters.is_unidirectional;
   p.local_window_size = parameters.local_window_size;  // -1 disables
+  p.past_kv_length = parameters.total_sequence_length - parameters.sequence_length;
   p.scale = scale;
   p.softcap = parameters.softcap;
   p.seqlens_k = data.total_seq_lens;

-  ORT_RETURN_IF_ERROR((LaunchGqaUnfusedAttention(
+  ORT_RETURN_IF_ERROR((LaunchUnfusedAttention(
       device_prop, cublas, stream, p,
       data.unfused_q_bnsh,
       reinterpret_cast<const T*>(data.present_key),
       reinterpret_cast<const T*>(data.present_value),
       /*attn_bias=*/nullptr,
       data.unfused_y_bnsh,
-      data.unfused_workspace)));
+      data.unfused_workspace,
+      /*output_qk=*/nullptr)));

   // Step 4: transpose output BNSH → BSNH into data.output.
   // Use p.v_head_size (== head_size per ORT_ENFORCE) for semantic correctness.
diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu
similarity index 77%
rename from onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu
rename to onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu
index 8aac549aeba01..a0c9d4666cae3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu
@@ -1,8 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-// GQA-capable unfused CUDA attention kernel. See header for contract.
+// Unified unfused CUDA attention kernel. See header for contract.

+#include <algorithm>
 #include
 #include "core/providers/cuda/cu_inc/cub.cuh"
 #include
@@ -13,7 +14,7 @@
 #include "core/providers/cuda/cuda_type_conversion.h"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 #include "core/providers/cuda/shared_inc/fpgeneric.h"
-#include "contrib_ops/cuda/bert/gqa_unfused_attention.h"
+#include "contrib_ops/cuda/bert/unfused_attention.h"

 using onnxruntime::cuda::OrtToCudaType;

@@ -38,10 +39,37 @@ __device__ __forceinline__ float ToFloat<__half>(__half v) { return __half2float(v); }
 template <>
 __device__ __forceinline__ float ToFloat<__nv_bfloat16>(__nv_bfloat16 v) { return __bfloat162float(v); }

+// Device helper: convert float to T.
+template <typename T>
+__device__ __forceinline__ T FromFloat(float v);
+template <>
+__device__ __forceinline__ float FromFloat<float>(float v) { return v; }
+template <>
+__device__ __forceinline__ __half FromFloat<__half>(float v) { return __float2half(v); }
+template <>
+__device__ __forceinline__ __nv_bfloat16 FromFloat<__nv_bfloat16>(float v) { return __float2bfloat16(v); }
+
 inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) {
   return SafeInt<size_t>(batch_size) * num_heads * q_seq * total_kv;
 }

+// ---------------------------------------------------------------------------
+// CopyQK kernel: copies FP32 QK scratch to T output with scale applied.
+// output_qk[i] = T(qk_fp32[i] * scale) for i in [0, total_elements).
+// ---------------------------------------------------------------------------
+template <typename T, int TPB>
+__global__ void ScaledCopyQkKernel(
+    const float* __restrict__ qk_fp32,
+    T* __restrict__ output_qk,
+    const float scale,
+    const int64_t total_elements) {
+  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * TPB + threadIdx.x;
+       idx < total_elements;
+       idx += static_cast<int64_t>(gridDim.x) * TPB) {
+    output_qk[idx] = FromFloat<T>(qk_fp32[idx] * scale);
+  }
+}
+
 // ---------------------------------------------------------------------------
 // Softmax kernel: reads FP32 QK scores, writes T softmax output.
 //
@@ -56,7 +84,7 @@ inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) {
 // total_kv_length. Handles fully-masked rows by emitting zeros (no NaN).
 // ---------------------------------------------------------------------------
 template <typename T, int TPB>
-__global__ void GqaUnfusedSoftmaxKernel(
+__global__ void UnfusedSoftmaxKernel(
     const int q_sequence_length,
     const int total_kv_length,
     const int num_heads,  // N_q
@@ -68,6 +96,7 @@ __global__ void GqaUnfusedSoftmaxKernel(
     const int* __restrict__ seqlens_k,
     const bool is_causal,
     const int local_window_size,
+    const int past_kv_length,
     const float scale,
     const float softcap,
     T* __restrict__ softmax_out) {
@@ -82,12 +111,13 @@ __global__ void GqaUnfusedSoftmaxKernel(
     if (v < kv_end) kv_end = v;
     if (v < 0) kv_end = 0;
   }
-  // past (number of KV positions before the current query tokens) must be
-  // per-batch when seqlens_k is provided, since different batches can have
-  // different amounts of valid past context. Using the global total_kv_length
-  // would over-estimate past for short batches and shift the sliding-window
-  // start past kv_end, producing an all-masked (zero) row.
-  const int past = kv_end - q_sequence_length;
+  // past_kv_length is the number of KV positions that precede the current query
+  // tokens. For upper-left causal alignment (ONNX Attention with no past),
+  // this is 0. For lower-right alignment (decode with past), this is
+  // total_kv_length - q_sequence_length.
+  // When seqlens_k varies per batch (GQA sliding window), derive per-batch
+  // so the window cutoff stays within the valid range for shorter batches.
+  const int past = (seqlens_k != nullptr) ? (kv_end - q_sequence_length) : past_kv_length;
   const int q_pos = past + q_in_head;

   int end = kv_end;
@@ -191,16 +221,16 @@ __global__ void GqaUnfusedSoftmaxKernel(
 }

 template <typename T>
-void LaunchGqaUnfusedSoftmax(
+void LaunchUnfusedSoftmax(
     cudaStream_t stream,
-    const GqaUnfusedAttentionParams& params,
+    const UnfusedAttentionParams& params,
     const float* qk_in,
     const T* attn_bias,
     T* softmax_out) {
   const dim3 grid(params.num_heads * params.q_sequence_length, params.batch_size, 1);
   const bool has_bias = (attn_bias != nullptr);
   constexpr int TPB = 256;
-  GqaUnfusedSoftmaxKernel<T, TPB><<<grid, TPB, 0, stream>>>(
+  UnfusedSoftmaxKernel<T, TPB><<<grid, TPB, 0, stream>>>(
       params.q_sequence_length,
       params.total_kv_length,
       params.num_heads,
@@ -212,6 +242,7 @@ void LaunchUnfusedSoftmax(
       params.seqlens_k,
       params.is_causal,
       params.local_window_size,
+      params.past_kv_length,
       params.scale,
       params.softcap,
       softmax_out);
@@ -250,7 +281,7 @@ template <typename T>
 common::Status LaunchQkGemmFp32(
     const cudaDeviceProp& /*device_prop*/,
     cublasHandle_t cublas,
-    const GqaUnfusedAttentionParams& params,
+    const UnfusedAttentionParams& params,
     const T* query,
     const T* key,
     float* qk_out) {
@@ -292,7 +323,7 @@ common::Status LaunchQkGemmFp32(
       CUBLAS_GEMM_DEFAULT);

   if (status != CUBLAS_STATUS_SUCCESS) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention QK GEMM failed: ", status);
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "UnfusedAttention QK GEMM failed: ", status);
   }
   return common::Status::OK();
 }
@@ -312,7 +343,7 @@ common::Status LaunchQkGemmFp32(
 template <typename T>
 common::Status LaunchAttnVGemm(
     cublasHandle_t cublas,
-    const GqaUnfusedAttentionParams& params,
+    const UnfusedAttentionParams& params,
     const T* softmax_out,
     const T* value,
     T* output) {
@@ -347,7 +378,7 @@ common::Status LaunchAttnVGemm(
       CUBLAS_COMPUTE_32F,
       CUBLAS_GEMM_DEFAULT);
   if (status != CUBLAS_STATUS_SUCCESS) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention AV GEMM failed: ", status);
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "UnfusedAttention AV GEMM failed: ", status);
   }
   return common::Status::OK();
 }
@@ -357,10 +388,10 @@ common::Status LaunchAttnVGemm(
 // ---------------------------------------------------------------------------
 // Public API
 // ---------------------------------------------------------------------------
-size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size,
-                                           int num_heads,
-                                           int q_sequence_length,
-                                           int total_kv_length) {
+size_t GetUnfusedAttentionWorkspaceSize(int batch_size,
+                                        int num_heads,
+                                        int q_sequence_length,
+                                        int total_kv_length) {
   const size_t elems = QkElementCount(batch_size, num_heads, q_sequence_length, total_kv_length);
   // FP32 QK scratch + T softmax scratch. We always allocate sizeof(float) per
   // element for the T scratch too (upper bound); caller can cast appropriately.
@@ -370,26 +401,27 @@ size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size,
 }

 template <typename T>
-common::Status LaunchGqaUnfusedAttention(
+common::Status LaunchUnfusedAttention(
     const cudaDeviceProp& device_prop,
     cublasHandle_t cublas,
     cudaStream_t stream,
-    const GqaUnfusedAttentionParams& params,
+    const UnfusedAttentionParams& params,
     const T* query,
     const T* key,
     const T* value,
     const T* attn_bias,
     T* output,
-    void* workspace) {
+    void* workspace,
+    T* output_qk) {
   ORT_RETURN_IF_NOT(params.batch_size > 0 && params.num_heads > 0 && params.kv_num_heads > 0 &&
                         params.head_size > 0 && params.v_head_size > 0 &&
                         params.q_sequence_length > 0 && params.total_kv_length > 0 &&
                         params.max_kv_length >= params.total_kv_length,
-                    "GqaUnfusedAttention: invalid params.");
+                    "UnfusedAttention: invalid params.");
   ORT_RETURN_IF_NOT(params.num_heads % params.kv_num_heads == 0,
-                    "GqaUnfusedAttention: num_heads (", params.num_heads,
+                    "UnfusedAttention: num_heads (", params.num_heads,
                     ") must be a multiple of kv_num_heads (", params.kv_num_heads, ").");
-  ORT_RETURN_IF(workspace == nullptr, "GqaUnfusedAttention: workspace is null.");
+  ORT_RETURN_IF(workspace == nullptr, "UnfusedAttention: workspace is null.");

   const size_t elems = QkElementCount(params.batch_size, params.num_heads,
                                       params.q_sequence_length, params.total_kv_length);
@@ -400,7 +432,21 @@ common::Status LaunchGqaUnfusedAttention(

   ORT_RETURN_IF_ERROR((LaunchQkGemmFp32(device_prop, cublas, params, query, key, qk_fp32)));

-  LaunchGqaUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T);
+  // Copy scaled QK to output_qk BEFORE softcap/mask/softmax.
+  // output_qk[i] = T(qk_fp32[i] * scale) — this is "kQK" mode (scale * Q @ K^T).
+  // Note: When seqlens_k is provided, positions [seqlens_k[b], total_kv) in output_qk
+  // may contain stale KV cache data. Consumers of output_qk should only read positions
+  // [0, seqlens_k[b]) for batch b.
+  if (output_qk != nullptr) {
+    const int64_t total = static_cast<int64_t>(elems);
+    constexpr int kTPB = 256;
+    constexpr int kMaxBlocks = 65535;
+    const int blocks = static_cast<int>(std::min(static_cast<int64_t>(kMaxBlocks), (total + kTPB - 1) / kTPB));
+    ScaledCopyQkKernel<T, kTPB><<<blocks, kTPB, 0, stream>>>(qk_fp32, output_qk, params.scale, total);
+    CUDA_RETURN_IF_ERROR(cudaGetLastError());
+  }
+
+  LaunchUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T);
   CUDA_RETURN_IF_ERROR(cudaGetLastError());

   ORT_RETURN_IF_ERROR((LaunchAttnVGemm(cublas, params, softmax_T, value, output)));
@@ -409,18 +455,18 @@ common::Status LaunchGqaUnfusedAttention(
 }

 // Explicit template instantiations.
-template common::Status LaunchGqaUnfusedAttention<__half>(
+template common::Status LaunchUnfusedAttention<__half>(
     const cudaDeviceProp&, cublasHandle_t, cudaStream_t,
-    const GqaUnfusedAttentionParams&, const __half*, const __half*, const __half*,
-    const __half*, __half*, void*);
-template common::Status LaunchGqaUnfusedAttention<__nv_bfloat16>(
+    const UnfusedAttentionParams&, const __half*, const __half*, const __half*,
+    const __half*, __half*, void*, __half*);
+template common::Status LaunchUnfusedAttention<__nv_bfloat16>(
     const cudaDeviceProp&, cublasHandle_t, cudaStream_t,
-    const GqaUnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*,
-    const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*);
-template common::Status LaunchGqaUnfusedAttention<float>(
+    const UnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*,
+    const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*, __nv_bfloat16*);
+template common::Status LaunchUnfusedAttention<float>(
     const cudaDeviceProp&, cublasHandle_t, cudaStream_t,
-    const GqaUnfusedAttentionParams&, const float*, const float*, const float*,
-    const float*, float*, void*);
+    const UnfusedAttentionParams&, const float*, const float*, const float*,
+    const float*, float*, void*, float*);

 }  // namespace cuda
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.h
similarity index 77%
rename from onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h
rename to onnxruntime/contrib_ops/cuda/bert/unfused_attention.h
index 84d645cd2b349..8fb3a18ac7570 100644
--- a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.h
@@ -13,7 +13,7 @@ namespace contrib {
 namespace cuda {

 // ============================================================================
-// GQA Unfused Attention (CUDA fallback for large head_size / fp16 overflow)
+// Unified Unfused Attention (CUDA fallback for large head_size / fp16 overflow)
 // ============================================================================
 //
 // Purpose:
@@ -38,18 +38,20 @@ namespace cuda {
 //   - scale is applied to raw QK (before softcap / bias).
 //   - softcap (> 0) is applied after scale: x = softcap * tanh(x / softcap).
 //   - attn_bias (if non-null) is added after softcap (additive mask).
-//   - causal: k > (past + q) is -inf where past = total_kv - S_q.
+//   - causal: k > (past_kv_length + q) is -inf.
+//     When past_kv_length=0 (no past), gives upper-left alignment: q_i attends to kv[0..i].
+//     When past_kv_length=total_kv-S_q (decode with past), gives lower-right alignment.
 //   - local_window_size (>= 0): k < (past + q) - local_window_size is -inf.
 //     local_window_size == -1 disables the sliding-window mask.
 //
 // The new kernel is suitable only as a fallback when Flash / MEA are ineligible
-// (head_size > 256, past_key present with mask, GQA with MHA-only unfused, etc).
+// (head_size > 256, past_key present with mask, etc).
 // The QK GEMM runs with CUBLAS_COMPUTE_32F and writes a FP32 scratch to avoid
 // fp16 overflow.
 //
 // ============================================================================

-struct GqaUnfusedAttentionParams {
+struct UnfusedAttentionParams {
   int batch_size = 0;
   int num_heads = 0;     // N_q
   int kv_num_heads = 0;  // N_kv (num_heads % kv_num_heads == 0)
@@ -68,6 +70,7 @@ struct UnfusedAttentionParams {

   bool is_causal = false;
   int local_window_size = -1;  // -1 disables sliding window
+  int past_kv_length = 0;      // number of past KV positions (for causal alignment)

   float scale = 1.0f;
   float softcap = 0.0f;  // 0 disables
@@ -77,27 +80,30 @@ struct UnfusedAttentionParams {
 };

 // Returns required scratch size in bytes. Caller must allocate
-// GetGqaUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace.
-size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size,
-                                           int num_heads,
-                                           int q_sequence_length,
-                                           int total_kv_length);
+// GetUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace.
+size_t GetUnfusedAttentionWorkspaceSize(int batch_size,
+                                        int num_heads,
+                                        int q_sequence_length,
+                                        int total_kv_length);

 // Compute: Y = softmax(scale * Q * K^T [softcap, causal, window, bias, seqlens_k]) * V.
 // All pointers are on device. Q/K/V/output are in type T (fp16/bf16/float).
 // attn_bias (if present) is in type T.
+// output_qk (optional): when non-null, writes scale * Q @ K^T (FP32→T) before softcap/mask/softmax.
+// Shape: [B, N_q, S_q, total_kv]. Caller allocates.
 template <typename T>
-common::Status LaunchGqaUnfusedAttention(
+common::Status LaunchUnfusedAttention(
     const cudaDeviceProp& device_prop,
     cublasHandle_t cublas,
     cudaStream_t stream,
-    const GqaUnfusedAttentionParams& params,
+    const UnfusedAttentionParams& params,
     const T* query,
     const T* key,
     const T* value,
     const T* attn_bias,
     T* output,
-    void* workspace);
+    void* workspace,
+    T* output_qk = nullptr);

 }  // namespace cuda
 }  // namespace contrib
diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
index 228729745b65b..00ce18c65efd8 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -1,17 +1,20 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

+#include <algorithm>
+
 #include "core/common/safeint.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cpu/llm/attention.h"
 #include "core/providers/cpu/llm/attention_helper.h"
 #include "core/providers/cuda/llm/attention.h"
 #include "core/providers/cuda/llm/attention_mask_impl.h"
-#include "contrib_ops/cuda/bert/attention_data.h"
+// attention_impl.h provides Transpose_BNSH_to_BSNH / Transpose_BSNH_to_BNSH used
+// by the transpose helpers.
 #include "contrib_ops/cuda/bert/attention_impl.h"
 #include "contrib_ops/cuda/bert/attention_kv_cache.h"
 #include "contrib_ops/cuda/bert/group_query_attention_impl.h"
-#include "contrib_ops/cuda/bert/gqa_unfused_attention.h"
+#include "contrib_ops/cuda/bert/unfused_attention.h"
 #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
 #include "core/providers/cuda/cuda_type_conversion.h"
@@ -155,7 +158,12 @@ Status Attention<T>::ConvertAttnMaskToBias(
   int64_t num_elements = attn_mask->Shape().Size();
   converted_mask_buffer = GetScratchBuffer<void>(
       num_elements * sizeof(NativeCudaT), GetComputeStream(context));
-  float mask_filter_value = static_cast<float>(std::numeric_limits<float>::lowest());
+  // CUTLASS online softmax multiplies attention scores by kLog2e (≈1.4427).
+  // For float/bf16, |lowest() × kLog2e| > FLT_MAX, overflowing to -inf and
+  // causing s_prime=0 → NaN for fully-masked batches. Cap to prevent this.
+  // See kCutlassSafeMaskFilterValue in memory_efficient_attention.h for details.
+  float mask_filter_value = std::max(static_cast<float>(std::numeric_limits<float>::lowest()),
+                                     ::onnxruntime::contrib::cuda::kCutlassSafeMaskFilterValue);
   ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias(
       attn_mask->Data<bool>(),
       reinterpret_cast<NativeCudaT*>(converted_mask_buffer.get()),
@@ -189,7 +197,7 @@ Status Attention<T>::ConvertAttnMaskToBias(
 //   Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache
 //   Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache
 //     - No mask support (attn_mask rejected at eligibility)
-//     - 4D BNSH: transposes Q/K/V to BSNH before kernel
+//     - 4D BNSH: transposes Q to BSNH; new K/V to BSNH for concat (cache stays BNSH)
 //   Path 3: no past, no mask (prompt) -> mha_fwd
 // Eligibility: fp16/bf16, head_size==v_head_size, no output_qk, attn_mask==nullptr
 // Note: softcap is passed to the Flash kernel natively. softmax_precision is
@@ -334,10 +342,10 @@ Status Attention<T>::RunFlashAttention(
     ORT_ENFORCE(present_key != nullptr && present_value != nullptr,
                 "present_key/value outputs are required when past_key is provided.");

-    // TODO(titaiwang): Consolidate preprocessing (RoPE, mask conversion, KV cache concat) into a
+    // TODO(titaiwang): Consolidate preprocessing (transpose, KV cache concat) into a
     // single fused kernel like GQA's LaunchUnpackRoPEAppend. Current decode path uses 4-6 kernel
-    // launches; a fused approach would reduce to ~2, saving ~21μs launch overhead and ~256KB
-    // intermediate buffer traffic per decode step.
+    // launches; a fused approach would reduce to ~2, saving launch overhead and intermediate
+    // buffer traffic per decode step.

     // Concat past + new KV directly into present buffers using a single fused kernel.
     // This replaces the old pattern of memset + strided cudaMemcpy2DAsync + Flash's
@@ -476,7 +484,7 @@ Status Attention<T>::RunFlashAttention(
                                         cuda_stream, device_prop.maxThreadsPerBlock));
   }

-  // --- Populate present_key/value (BNSH) from K/V (BSNH) ---
+  // --- Populate present_key/value (BNSH) from K/V (BSNH or BNSH) ---
   // Skip for decode path where mha_fwd_kvcache already populated present buffers.
   if (!present_kv_already_populated) {
     if (present_key != nullptr && is_bsnh) {
@@ -528,13 +536,15 @@ Status Attention<T>::RunFlashAttention(
 // ============================================================================
 //
 // Memory Efficient Attention (cutlass FMHA) dispatch paths:
-//   Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
-//   Path 2: no past, with mask (prompt) -> standard MEA with additive bias
-//   Path 3: no past, no mask (prompt) -> standard MEA
+//   Path 1: Decode with past KV cache -> LaunchConcatNewToPastKV then standard MEA
+//   Path 2: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
+//   Path 3: Prompt with mask -> standard MEA with additive bias
+//   Path 4: Prompt without mask -> standard MEA
 // Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype,
-//   head_size <= 1024), plus: no output_qk, no past_key (decode excluded),
-//   bias stride alignment.
-// Note: softcap is forwarded to the MEA kernel via p.softcap. softmax_precision
+//   head_size <= 1024, head_size divisible by 8), plus: no output_qk, bias stride alignment.
+// Note: softcap is forwarded to the MEA kernel via p.softcap. CUTLASS applies
+//   softcap before bias (fused in kernel tiles), matching ONNX spec ordering
+//   (onnx/onnx#7865): QK → softcap → mask/bias → softmax. softmax_precision
 //   is inherently satisfied (cutlass FMHA accumulates softmax in FP32).
 //
 template <typename T>
@@ -546,8 +556,6 @@ Status Attention<T>::RunMemoryEfficientAttention(
     Tensor* Y, Tensor* present_key, Tensor* present_value,
     const attention_helper::AttentionParameters& parameters) const {
 #if USE_MEMORY_EFFICIENT_ATTENTION
-  ORT_UNUSED_PARAMETER(past_key);
-  ORT_UNUSED_PARAMETER(past_value);
   auto& device_prop = GetDeviceProp();
   auto cuda_stream = Stream(context);
   const bool is_bsnh = parameters.transpose_output;
@@ -582,6 +590,120 @@ Status Attention<T>::RunMemoryEfficientAttention(
     out_data = out_bsnh_buffer.get();
   }

+  bool present_kv_already_populated = false;
+  // Track the effective layout of k_data/v_data. Initially matches input layout,
+  // but changes to BNSH (false) after decode concat into present buffers.
+  bool kv_is_bsnh = is_bsnh;
+
+  // Scratch buffers for decode concat output when present_key/value are optional.
+  // Declared at function scope so they outlive the decode block (k_data/v_data may point here).
+  IAllocatorUniquePtr<void> present_k_scratch;
+  IAllocatorUniquePtr<void> present_v_scratch;
+
+  // --- Decode path: concat past + new K/V → present buffers (BNSH) ---
+  // nonpad_kv_seqlen and past_key are mutually exclusive (enforced at validation),
+  // so the decode path only needs the internal-cache (past_key/present_key) flow.
+  if (past_key != nullptr) {
+    ORT_RETURN_IF_NOT(past_value != nullptr, "past_key requires past_value.");
+    ORT_RETURN_IF_NOT(nonpad_kv_seqlen == nullptr,
+                      "nonpad_kv_seqlen and past_key are mutually exclusive (internal vs external cache).");
+    // This mirrors the eligibility check in ComputeInternal — must stay in sync.
+    ORT_RETURN_IF_NOT(parameters.head_size == parameters.v_head_size,
+                      "MEA decode (past_key) requires head_size == v_head_size for LaunchConcatNewToPastKV.");
+
+    using NativeCudaT = typename OrtToCudaType<T>::type;
+
+    // Allocate scratch buffers for concat output when present_key/value are not requested.
+    // The concat kernel needs a destination buffer regardless of whether the caller wants present outputs.
+    T* present_k_data = nullptr;
+    T* present_v_data = nullptr;
+
+    SafeInt<size_t> present_k_bytes = SafeInt<size_t>(parameters.batch_size) * parameters.kv_num_heads *
+                                      parameters.total_sequence_length * parameters.head_size * sizeof(T);
+    SafeInt<size_t> present_v_bytes = SafeInt<size_t>(parameters.batch_size) * parameters.kv_num_heads *
+                                      parameters.total_sequence_length * parameters.v_head_size * sizeof(T);
+
+    if (present_key != nullptr) {
+      present_k_data = present_key->MutableData<T>();
+    } else {
+      present_k_scratch = GetScratchBuffer<void>(present_k_bytes, GetComputeStream(context));
+      present_k_data = static_cast<T*>(present_k_scratch.get());
+    }
+    if (present_value != nullptr) {
+      present_v_data = present_value->MutableData<T>();
+    } else {
+      present_v_scratch = GetScratchBuffer<void>(present_v_bytes, GetComputeStream(context));
+      present_v_data = static_cast<T*>(present_v_scratch.get());
+    }
+
+    // Step 1: Uniform past sequence lengths for the concat kernel.
+    // ONNX past_key has shape [B, H, past_seq, head_size] — all batches share
+    // the same past_seq dimension. Bool masks do NOT change where tokens are stored;
+    // they change which tokens are attended to (via additive bias, handled below).
+    auto past_seqlens_buffer = GetScratchBuffer<int32_t>(parameters.batch_size, GetComputeStream(context));
+    ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length,
+                                        parameters.batch_size, cuda_stream,
+                                        device_prop.maxThreadsPerBlock));
+
+    // Step 2: Transpose K/V to BSNH if input is 4D BNSH (concat kernel reads new as BSNH).
+    const T* k_new_bsnh = K->Data<T>();
+    const T* v_new_bsnh = V->Data<T>();
+    IAllocatorUniquePtr<void> k_bsnh_buffer;
+    IAllocatorUniquePtr<void> v_bsnh_buffer;
+    if (!is_bsnh) {
+      size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+                       parameters.kv_num_heads * parameters.head_size;
+      size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+                       parameters.kv_num_heads * parameters.v_head_size;
+      k_bsnh_buffer = GetScratchBuffer<void>(k_bytes, GetComputeStream(context));
+      v_bsnh_buffer = GetScratchBuffer<void>(v_bytes, GetComputeStream(context));
+      ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.head_size,
+          K->Data<T>(), k_bsnh_buffer.get(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+      ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.v_head_size,
+          V->Data<T>(), v_bsnh_buffer.get(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+      k_new_bsnh = static_cast<const T*>(k_bsnh_buffer.get());
+      v_new_bsnh = static_cast<const T*>(v_bsnh_buffer.get());
+    }
+
+    // Step 3: Fused concat: past_key + new_key → present_key (and same for values).
+    // One kernel copies past data from [0, past_seq) and new data from BSNH layout
+    // into present buffer at [past_seq, past_seq + kv_seq), all in BNSH.
+    // No memset needed: uniform past_seq_lens means every position in the present
+    // buffer is written by the concat kernel. Padding positions in past_key are copied
+    // as-is; the attention mask (additive bias) handles correctness at the attention level.
+    ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV(
+        parameters.batch_size,
+        parameters.kv_num_heads,
+        parameters.head_size,
+        parameters.kv_sequence_length,
+        parameters.past_sequence_length,
+        parameters.total_sequence_length,
+        /*is_bsnh=*/false,
+        past_seqlens_buffer.get(),
+        /*total_seq_lens=*/nullptr,
+        reinterpret_cast<const NativeCudaT*>(past_key->Data<T>()),
+        reinterpret_cast<const NativeCudaT*>(past_value->Data<T>()),
+        reinterpret_cast<const NativeCudaT*>(k_new_bsnh),
+        reinterpret_cast<const NativeCudaT*>(v_new_bsnh),
+        reinterpret_cast<NativeCudaT*>(present_k_data),
+        reinterpret_cast<NativeCudaT*>(present_v_data),
+        cuda_stream,
+        device_prop.maxThreadsPerBlock,
+        /*past_only=*/false));
+
+    // Point MEA's K/V inputs at the concatenated buffers (BNSH).
+    k_data = present_k_data;
+    v_data = present_v_data;
+    kv_is_bsnh = false;
+    present_kv_already_populated = true;
+  }
+
   // GQA head expansion: MEA requires matching num_heads for Q/K/V.
   // When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup.
   const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads;
@@ -622,7 +744,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
           reinterpret_cast<const NativeCudaT*>(v_data),
           parameters.total_sequence_length,
           parameters.total_sequence_length,
-          is_bsnh,
+          kv_is_bsnh,
           cuda_stream,
           device_prop.maxThreadsPerBlock));
@@ -631,8 +753,8 @@ Status Attention<T>::RunMemoryEfficientAttention(
     }
   }

-  // Note: MEA with past_key/value is handled by the unfused fallback.
-  // The cascade in ComputeInternal ensures past_key == nullptr when we reach here.
+  // Note: When past_key is present (decode), k_data/v_data already point to present
+  // buffers (BNSH) after LaunchConcatNewToPastKV above, so MEA sees the full cache.

   // Handle attention mask → attention_bias conversion
   IAllocatorUniquePtr<void> converted_mask_buffer;
@@ -642,7 +764,8 @@ Status Attention<T>::RunMemoryEfficientAttention(

   if (nonpad_kv_seqlen != nullptr) {
     // Convert nonpad_kv_seqlen to seqlens_k for custom right padding.
-    // MEA expects actual token count (not count-1), so use FlashSeqlensK variant.
+    // MEA expects seqlens_k as actual token count, so use FlashSeqlensK variant
+    // (which converts int64→int32 without subtracting 1).
     auto seqlens_k_buffer = GetScratchBuffer<int32_t>(parameters.batch_size, GetComputeStream(context));
     ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK(
         nonpad_kv_seqlen->Data<int64_t>(),
@@ -665,7 +788,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
     p.sm = sm;
     p.is_half = std::is_same<T, MLFloat16>::value;
     p.is_bf16 = std::is_same<T, BFloat16>::value;
-    p.is_kv_bsnh = is_bsnh;
+    p.is_kv_bsnh = kv_is_bsnh;
     p.batch_size = parameters.batch_size;
     p.num_heads = parameters.q_num_heads;
     p.sequence_length = parameters.q_sequence_length;
@@ -674,6 +797,15 @@ Status Attention<T>::RunMemoryEfficientAttention(
     p.qk_head_size = parameters.head_size;
     p.v_head_size = parameters.v_head_size;
     p.causal = parameters.is_causal;
+    // ONNX spec: is_causal means upper-left alignment in the full attention matrix.
+    // When past_sequence_length == 0 and S_q != S_kv (cross-attention without KV cache),
+    // queries start at absolute position 0, so causal mask is upper-left.
+    // When past_sequence_length > 0 (decode with KV cache), queries start at position
+    // past_seq, so causal mask is effectively lower-right on the [S_q x total_kv] sub-matrix.
+    // NOTE: For external KV cache (TensorScatter), nonpad_kv_seqlen provides per-batch
+    // actual lengths and seqlens_k handles the masking — the causal_from_top_left flag
+    // is only consulted when params.causal is true, so it's correct here.
+    p.causal_from_top_left = (parameters.past_sequence_length == 0);
     p.scale = parameters.scale;
     p.softcap = parameters.softcap;
     p.seqlen_k_ptr = seqlens_k_buffer.get();
@@ -700,8 +832,12 @@ Status Attention<T>::RunMemoryEfficientAttention(
     onnxruntime::contrib::cuda::run_memory_efficient_attention(p);

     // On the MEA (CUTLASS) path (used for both MHA and GQA when nonpad_kv_seqlen is provided),
-    // zero out output for fully-masked batches to produce zeros (matching Flash behavior).
+    // zero out output for fully-masked batches to prevent NaN.
     // CUTLASS epilogue computes 1/s_prime where s_prime=0 for seqlens_k=0, producing NaN.
+    // TODO(titaiwang): ZeroOutputForFullyMaskedBatches outputs zeros for fully-masked
+    // batches (seqlens_k=0), which diverges from CPU/Unfused behavior (uniform mean of V).
+    // For cross-EP consistency, replace with LaunchMeanOfVForFullyMaskedBatches that
+    // computes mean(V[b,n,:,h]) for each masked batch. See issue #27516.
     {
       using CudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
       int64_t elements_per_batch = static_cast<int64_t>(parameters.q_sequence_length) *
@@ -716,9 +852,10 @@ Status Attention<T>::RunMemoryEfficientAttention(
     }
   }
   // Standard MEA path: float attention bias, bool mask (converted to bias), or no mask.
-  // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value)
-  // which correctly handles all-false masks (uniform softmax weights) unlike the
-  // custom_right_padding seqlens approach which would produce NaN.
+  // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value).
+  // For fully-masked batches (all-false bool mask), ConvertAttnMaskToBias uses a capped
+  // mask_filter_value (-1e+30) that stays finite through CUTLASS's kLog2e multiplication,
+  // producing correct uniform softmax → mean(V) output.
   else {
     if (attn_mask != nullptr) {
       ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream,
@@ -731,7 +868,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
     p.sm = sm;
     p.is_half = std::is_same<T, MLFloat16>::value;
     p.is_bf16 = std::is_same<T, BFloat16>::value;
-    p.is_kv_bsnh = is_bsnh;
+    p.is_kv_bsnh = kv_is_bsnh;
     p.batch_size = parameters.batch_size;
     p.num_heads = parameters.q_num_heads;
     p.sequence_length = parameters.q_sequence_length;
@@ -740,6 +877,8 @@ Status Attention<T>::RunMemoryEfficientAttention(
     p.qk_head_size = parameters.head_size;
     p.v_head_size = parameters.v_head_size;
     p.causal = parameters.is_causal;
+    // Causal alignment: same logic as above — upper-left when no past.
+    p.causal_from_top_left = (parameters.past_sequence_length == 0);
     p.scale = parameters.scale;
     p.softcap = parameters.softcap;
     p.broadcast_attn_bias_dim_0 = broadcast_bias_dim_0;
@@ -773,30 +912,33 @@ Status Attention<T>::RunMemoryEfficientAttention(
                                         cuda_stream, device_prop.maxThreadsPerBlock));
   }

-  // Populate present_key/present_value (BNSH) if requested
-  if (present_key != nullptr && is_bsnh) {
-    ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
-        parameters.batch_size, parameters.kv_sequence_length,
-        parameters.kv_num_heads, parameters.head_size,
-        K->Data<T>(), present_key->MutableData<T>(),
-        cuda_stream, device_prop.maxThreadsPerBlock));
-  } else if (present_key != nullptr && !is_bsnh) {
-    // 4D BNSH prompt: K is already BNSH, just D2D copy to present
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
-        present_key->MutableData<T>(), K->Data<T>(),
-        K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
-  }
-  if (present_value != nullptr && is_bsnh) {
-    ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
-        parameters.batch_size, parameters.kv_sequence_length,
-        parameters.kv_num_heads, parameters.v_head_size,
-        V->Data<T>(), present_value->MutableData<T>(),
-        cuda_stream, device_prop.maxThreadsPerBlock));
-  } else if (present_value != nullptr && !is_bsnh) {
-    // 4D BNSH prompt: V is already BNSH, just D2D copy to present
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
-        present_value->MutableData<T>(), V->Data<T>(),
-        V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+  // Populate present_key/present_value (BNSH) if requested.
+  // Skip for decode path where LaunchConcatNewToPastKV already populated present buffers.
+  if (!present_kv_already_populated) {
+    if (present_key != nullptr && is_bsnh) {
+      ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.head_size,
+          K->Data<T>(), present_key->MutableData<T>(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+    } else if (present_key != nullptr && !is_bsnh) {
+      // 4D BNSH prompt: K is already BNSH, just D2D copy to present
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+          present_key->MutableData<T>(), K->Data<T>(),
+          K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+    }
+    if (present_value != nullptr && is_bsnh) {
+      ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.v_head_size,
+          V->Data<T>(), present_value->MutableData<T>(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+    } else if (present_value != nullptr && !is_bsnh) {
+      // 4D BNSH prompt: V is already BNSH, just D2D copy to present
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+          present_value->MutableData<T>(), V->Data<T>(),
+          V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+    }
   }

   return Status::OK();
@@ -819,250 +961,30 @@ Status Attention<T>::RunMemoryEfficientAttention(
 }

 // ============================================================================
-// RunUnfusedAttention: Delegates to MHA's QkvToContext (unfused GEMM+softmax+GEMM)
-// ============================================================================
-//
-// Unfused Attention dispatch paths:
-//   Universal fallback via MHA's QkvToContext.
-//   Path 1: nonpad_kv_seqlen only -> converts to attention_bias [B, q_seq, total_seq]
-//   Path 2: nonpad_kv_seqlen + attn_mask -> composes both into attention_bias [B, q_seq, total_seq]
-//           (nonpad bias + mask bias added element-wise with cyclic broadcasting)
-//   Path 3: all other cases -> passes mask/bias directly
-// Supports: all dtypes (fp16/bf16/fp32), all mask types (bool/float/none), all head sizes
-// Not supported: softcap (rejected at fallback), output_qk modes beyond kNone/kQK
-// Limitation: MHA only (q_num_heads must equal kv_num_heads)
-//
-template <typename T>
-Status Attention<T>::RunUnfusedAttention(
-    OpKernelContext* context,
-    const Tensor* Q, const Tensor* K, const Tensor* V,
-    const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
-    const Tensor* nonpad_kv_seqlen,
-    Tensor* Y, Tensor* present_key, Tensor* present_value,
-    Tensor* output_qk,
-    const attention_helper::AttentionParameters& parameters) const {
-  using CudaT = typename ToCudaType<T>::MappedType;
-  // OrtToCudaType maps BFloat16 → __nv_bfloat16 (native HW type), matching kernel instantiations.
-  using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
-  auto& device_prop = GetDeviceProp();
-  auto cuda_stream = Stream(context);
-  auto ort_stream = GetOrtStream(context);
-
-  // Bridge to contrib::AttentionParameters for the MHA unfused path
-  onnxruntime::contrib::AttentionParameters contribop_parameters;
-
-  if (!parameters.transpose_output) {
-    contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
-    contribop_parameters.is_output_bnsh = true;
-  } else {
-    contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
-    contribop_parameters.is_output_bnsh = false;
-  }
-
-  contribop_parameters.batch_size = parameters.batch_size;
-  contribop_parameters.sequence_length = parameters.q_sequence_length;
-  contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
-  contribop_parameters.past_sequence_length = parameters.past_sequence_length;
-  contribop_parameters.total_sequence_length = parameters.total_sequence_length;
-  contribop_parameters.max_sequence_length = parameters.total_sequence_length;
-  contribop_parameters.input_hidden_size = 0;
-  contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
-  contribop_parameters.head_size = parameters.head_size;
-  contribop_parameters.v_head_size = parameters.v_head_size;
-  contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
-  contribop_parameters.num_heads = parameters.q_num_heads;
-  contribop_parameters.rotary_dim = 0;
-  contribop_parameters.num_splits = 1;
-  contribop_parameters.beam_width = 1;
-  contribop_parameters.is_unidirectional = parameters.is_causal;
-  contribop_parameters.past_present_share_buffer = false;
-  contribop_parameters.is_packed_qkv = false;
-  contribop_parameters.do_rotary = false;
-  contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
-  contribop_parameters.mask_filter_value = static_cast<float>(std::numeric_limits<float>::lowest());
-  contribop_parameters.scale = parameters.scale;
-  contribop_parameters.use_tf32 = UseTF32();
-
-  // Determine broadcast flags for attention_bias
-  if (attn_mask != nullptr) {
-    size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
-    auto attn_mask_dims = attn_mask->Shape().GetDims();
-    if (attn_mask_dims_size == 2) {
-      contribop_parameters.broadcast_attn_bias_dim_0 = true;
-      contribop_parameters.broadcast_attn_bias_dim_1 = true;
-    } else if (attn_mask_dims_size == 3) {
-      contribop_parameters.broadcast_attn_bias_dim_0 = true;
-      contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1;
-    } else {
-      contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
-      contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1;
-    }
-  } else {
-    contribop_parameters.broadcast_attn_bias_dim_0 = false;
-    contribop_parameters.broadcast_attn_bias_dim_1 = false;
-  }
-
-  // Construct AttentionData
-  onnxruntime::contrib::cuda::AttentionData<CudaT> data;
-  data.query = reinterpret_cast<const CudaT*>(Q->Data<T>());
-  data.key = reinterpret_cast<const CudaT*>(K->Data<T>());
-  data.value = reinterpret_cast<const CudaT*>(V->Data<T>());
-  data.mask_index = nullptr;
-  data.mask_index_dims = gsl::span<const int64_t>();
-  data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
-  data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
-  data.output = reinterpret_cast<CudaT*>(Y->MutableData<T>());
-  data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
-  data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
-  if (output_qk != nullptr) {
-    data.output_qk = reinterpret_cast<CudaT*>(output_qk->MutableData<T>());
-  }
-  data.bias = nullptr;
-
-  // Handle attention mask / nonpad_kv_seqlen → attention_bias
-  IAllocatorUniquePtr<void> converted_mask_buffer;
-  IAllocatorUniquePtr<void> mask_bias_buffer;  // temp buffer for mask→bias when composing
-  if (nonpad_kv_seqlen != nullptr) {
-    // Convert nonpad_kv_seqlen to additive attention bias: [B, q_seq, total_seq]
-    int64_t bias_elements = static_cast<int64_t>(parameters.batch_size) *
-                            parameters.q_sequence_length *
-                            parameters.total_sequence_length;
-    converted_mask_buffer = GetScratchBuffer<void>(bias_elements * sizeof(NativeCudaT), GetComputeStream(context));
-    ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias(
-        nonpad_kv_seqlen->Data<int64_t>(),
-        reinterpret_cast<NativeCudaT*>(converted_mask_buffer.get()),
-        parameters.batch_size,
-        parameters.q_sequence_length,
-        parameters.total_sequence_length,
-        contribop_parameters.mask_filter_value,
-        cuda_stream,
-        device_prop.maxThreadsPerBlock));
-
-    // When attn_mask is also present, compose it into the nonpad bias additively.
-    // The nonpad bias is [B, q, t]; the mask is added with cyclic broadcasting
-    // (e.g. a 2D [q, t] mask repeats over the batch dimension).
-    // Only 2D masks and 4D masks with head_dim=1 are supported — per-head masks
-    // (3D [H,q,t] or 4D [B,H>1,q,t]) cannot be composed into a [B,q,t] buffer.
-    if (attn_mask != nullptr) {
-      const auto& mask_shape = attn_mask->Shape();
-      int mask_dims = static_cast<int>(mask_shape.NumDimensions());
-      ORT_ENFORCE(mask_dims == 2 || (mask_dims == 4 && mask_shape[1] == 1),
-                  "nonpad_kv_seqlen + attn_mask composition in unfused path only supports "
-                  "2D masks [q, t] and 4D masks with head_dim=1 [B, 1, q, t]. "
-                  "Got mask shape: ",
-                  mask_shape);
-
-      int64_t mask_elements = mask_shape.Size();
-      const NativeCudaT* mask_bias_ptr = nullptr;
-
-      if (attn_mask->IsDataType<bool>()) {
-        // Convert bool mask to additive bias in a temp buffer, then add in-place.
-        mask_bias_buffer = GetScratchBuffer<void>(mask_elements * sizeof(NativeCudaT), GetComputeStream(context));
-        ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias(
-            attn_mask->Data<bool>(),
-            reinterpret_cast<NativeCudaT*>(mask_bias_buffer.get()),
-            mask_elements,
-            contribop_parameters.mask_filter_value,
-            cuda_stream,
-            device_prop.maxThreadsPerBlock));
-        mask_bias_ptr = reinterpret_cast<const NativeCudaT*>(mask_bias_buffer.get());
-      } else {
-        // Float mask is already in additive bias format.
-        mask_bias_ptr = reinterpret_cast<const NativeCudaT*>(attn_mask->Data<T>());
-      }
-
-      // Add mask bias into nonpad bias with cyclic broadcasting.
-      // 2D mask [q, t]: mask_elements = q*t, repeats for each batch → correct.
-      // 4D mask [B, 1, q, t]: mask_elements = B*q*t = bias_elements → direct add.
-      ORT_RETURN_IF_ERROR(LaunchAddBiasInPlace(
-          reinterpret_cast<NativeCudaT*>(converted_mask_buffer.get()),
-          mask_bias_ptr,
-          bias_elements,
-          mask_elements,
-          cuda_stream,
-          device_prop.maxThreadsPerBlock));
-    }
-
-    data.attention_bias = reinterpret_cast<const CudaT*>(converted_mask_buffer.get());
-    // Composed bias is [B, q_seq, total_seq] → broadcasts over heads but not batch.
- contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask != nullptr) { - if (attn_mask->IsDataType()) { - int64_t num_elements = attn_mask->Shape().Size(); - converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( - attn_mask->Data(), - reinterpret_cast(converted_mask_buffer.get()), - num_elements, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); - } else { - data.attention_bias = reinterpret_cast(attn_mask->Data()); - } - } - - data.qkv_format = contribop_parameters.qkv_format; - data.use_flash_attention = false; - data.use_memory_efficient_attention = false; - data.fused_runner = nullptr; - data.fused_cross_attention_kernel = nullptr; - data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; - - // Allocate workspace - const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); - size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( - sizeof(T), - contribop_parameters.batch_size, - contribop_parameters.num_heads, - contribop_parameters.head_size, - contribop_parameters.v_head_size, - contribop_parameters.sequence_length, - contribop_parameters.kv_sequence_length, - contribop_parameters.total_sequence_length, - nullptr, false, false, false, false, false, - no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); - - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.workspace_bytes = workspace_bytes; - - cublasHandle_t cublas = GetCublasHandle(context); - cudnnHandle_t cudnn = GetCudnnHandle(context); - - // Note: unfused attention produces valid finite output (mean-of-V via uniform softmax) - // for fully-masked batches, so ZeroOutput is not needed here. Only MEA requires - // ZeroOutput to prevent NaN from the CUTLASS epilogue's 1/s_prime division. - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data); -} - -// ============================================================================ -// RunGqaUnfusedAttention: GQA-capable unfused path + large-head fp16/bf16 fix +// RunUnfusedAttention: Unified unfused path for both MHA and GQA // ============================================================================ // -// Routes to LaunchGqaUnfusedAttention from contrib_ops/cuda/bert/gqa_unfused_attention.h. +// Routes to LaunchUnfusedAttention from contrib_ops/cuda/bert/unfused_attention.h. // // Handles: +// - MHA as a degenerate case (group_size=1, no head expansion needed). // - GQA natively (no K/V head replication; reshape-Q trick inside kernel). // - fp16/bf16 with large head_size via FP32 QK scratch (fixes issue #28195: // unfused attention producing NaN when head_dim > 256 at scale=1.0). // - Different Q/K sequence lengths, past_key+past_value, nonpad_kv_seqlen. // - attn_mask (bool/float, 2D/3D/4D), causal, softcap. // -// Not supported here (caller rejects upstream): -// - output_qk: only MHA unfused emits QK, so this path requires output_qk==nullptr. +// Not supported (returns NOT_IMPLEMENTED upstream): +// - qk_matmul_output_mode beyond kNone/kQK (kQKMask, kQKSoftCap, kQKSoftMax). 
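+// Note on kQK: CopyQK runs before ApplySoftcap and masking, so callers observe
+// the raw scale * Q @ K^T logits (exercised by the
+// Attention4DSoftCapOutputQkRawLogits test in attention_op_test.cc).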
// ============================================================================ template -Status Attention::RunGqaUnfusedAttention( +Status Attention::RunUnfusedAttention( OpKernelContext* context, const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const { using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; auto& device_prop = GetDeviceProp(); @@ -1108,9 +1030,6 @@ Status Attention::RunGqaUnfusedAttention( ORT_ENFORCE(past_value != nullptr, "past_key requires past_value."); ORT_ENFORCE(present_key != nullptr && present_value != nullptr, "present_key/value outputs are required when past_key is provided."); - // LaunchConcatNewToPastKV uses a single head_size for both K and V caches. - ORT_RETURN_IF(H != H_v, - "RunGqaUnfusedAttention: past_key with H != H_v not supported"); auto past_seqlens_buffer = GetScratchBuffer(B, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, B, @@ -1134,17 +1053,51 @@ Status Attention::RunGqaUnfusedAttention( v_new_bsnh = static_cast(v_bnsh_buffer.get()); } - ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( - B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, - /*is_bsnh=*/false, - past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, - reinterpret_cast(past_key->Data()), - reinterpret_cast(past_value->Data()), - reinterpret_cast(k_new_bsnh), - reinterpret_cast(v_new_bsnh), - reinterpret_cast(present_key->MutableData()), - reinterpret_cast(present_value->MutableData()), - cuda_stream, max_threads, /*past_only=*/false)); + if (H == H_v) { + // K and V have the same head_size -- single concat call handles both. + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + reinterpret_cast(past_key->Data()), + reinterpret_cast(past_value->Data()), + reinterpret_cast(k_new_bsnh), + reinterpret_cast(v_new_bsnh), + reinterpret_cast(present_key->MutableData()), + reinterpret_cast(present_value->MutableData()), + cuda_stream, max_threads, /*past_only=*/false)); + } else { + // H != H_v: LaunchConcatNewToPastKV uses a single head_size for both K and V + // (grid Z=0 for K, Z=1 for V with the same block dims). We must call it + // twice with different head_size values -- once for K (head_size=H) and once + // for V (head_size=H_v). Each call duplicates K data into V params (or vice + // versa) so both Z indices write to the same buffer harmlessly. + // + // Trade-off: each call does 2× GPU work (both Z slices execute). This is + // acceptable because H!=H_v decode through MEA is rare, and modifying the + // shared kernel (contrib_ops/cuda/bert/attention_kv_cache.cu) to support + // nullptr outputs or K-only/V-only modes would risk breaking GQA callers. 
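+      // Net cost: the first call writes the K concat twice and the second
+      // writes the V concat twice, so this branch does roughly 2x the copy
+      // work of the single-call H == H_v branch above.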
+ auto* pk = reinterpret_cast(past_key->Data()); + auto* pv = reinterpret_cast(past_value->Data()); + auto* nk = reinterpret_cast(k_new_bsnh); + auto* nv = reinterpret_cast(v_new_bsnh); + auto* out_k = reinterpret_cast(present_key->MutableData()); + auto* out_v = reinterpret_cast(present_value->MutableData()); + // Concat K with head_size=H (V params duplicate K data -- harmless) + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + pk, pk, nk, nk, out_k, out_k, + cuda_stream, max_threads, /*past_only=*/false)); + // Concat V with head_size=H_v (K params duplicate V data -- harmless) + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H_v, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + pv, pv, nv, nv, out_v, out_v, + cuda_stream, max_threads, /*past_only=*/false)); + } k_cache = reinterpret_cast(present_key->MutableData()); v_cache = reinterpret_cast(present_value->MutableData()); present_already_populated = true; @@ -1214,12 +1167,12 @@ Status Attention::RunGqaUnfusedAttention( } // -------- Allocate kernel workspace ----------------------------------------- - const size_t ws_bytes = onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + const size_t ws_bytes = onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize( B, N_q, S_q, total_kv); auto ws_buffer = GetScratchBuffer(ws_bytes, GetComputeStream(context)); // -------- Call the kernel --------------------------------------------------- - onnxruntime::contrib::cuda::GqaUnfusedAttentionParams p; + onnxruntime::contrib::cuda::UnfusedAttentionParams p; p.batch_size = B; p.num_heads = N_q; p.kv_num_heads = N_kv; @@ -1232,13 +1185,19 @@ Status Attention::RunGqaUnfusedAttention( p.broadcast_attn_bias_dim_1 = bcast1; p.is_causal = parameters.is_causal; p.local_window_size = -1; // ONNX Attention (opset 23/24) does not expose sliding window. + p.past_kv_length = parameters.past_sequence_length; p.scale = parameters.scale; p.softcap = parameters.softcap; p.seqlens_k = seqlens_k_ptr; - ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchGqaUnfusedAttention( + NativeCudaT* output_qk_data = (output_qk != nullptr) + ? reinterpret_cast(output_qk->MutableData()) + : nullptr; + + ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchUnfusedAttention( device_prop, GetCublasHandle(context), cuda_stream, - p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get()))); + p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get(), + output_qk_data))); // -------- Transpose output BNSH -> BSNH if input was 3D -------------------- if (is_bsnh && out_bnsh_buffer != nullptr) { @@ -1279,10 +1238,10 @@ Status Attention::RunGqaUnfusedAttention( // ============================================================================ // ComputeInternal: Dispatch to appropriate attention kernel // ============================================================================ -// MHA path (q_num_heads == kv_num_heads): uses direct kernel dispatch cascade -// flash → memory efficient → unfused -// GQA path (q_num_heads != kv_num_heads): uses flash (handles GQA natively), MEA -// (with head expansion via LaunchUngroup, fp16/bf16 only), or GQA unfused fallback. 
+// Dispatch cascade: Flash → MEA (Memory Efficient) → Unified Unfused Attention. +// The unified unfused kernel handles both MHA (num_heads == kv_num_heads) and +// GQA (num_heads != kv_num_heads) via a reshape-Q trick (no K/V head replication). +// MEA uses head expansion via LaunchUngroup (fp16/bf16 only) for GQA. // ============================================================================ template Status Attention::ComputeInternal(OpKernelContext* context) const { @@ -1331,12 +1290,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Flash: strictly requires BSNH — Q is transposed BNSH→BSNH before calling mha_fwd*. // K/V passed as BNSH to mha_fwd_kvcache (it handles both layouts). // MEA: accepts both BSNH and BNSH natively via is_kv_bsnh flag. Q transposed to BSNH. - // Unfused: accepts both via QkvToContext's qkv_format (Q_K_V_BSNH or Q_K_V_BNSH). + // Unfused: accepts both BSNH and BNSH (transposes if needed). // // nonpad_kv_seqlen + attn_mask routing: // Flash: cannot handle this combo (no bias param when seqlens_k is used) → excluded. // MEA: supports both (custom_right_padding for seqlens + additive attn_bias for mask). - // Unfused: nonpad → attention_bias; mask composed additively when both present. + // Unfused: nonpad → seqlens_k; mask → attention_bias; both handled independently in softmax kernel. #if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION const bool has_output_qk = (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone); #endif @@ -1347,6 +1306,39 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // softmax_precision=0 (default) is also fine since higher precision is always // acceptable per the ONNX spec. + // Flash Attention uses lower-right (bottom-right) causal alignment with no option for + // upper-left. The ONNX spec requires upper-left alignment when there is no past context: + // query[0] attends only to key[0]. The difference only manifests when S_q != S_kv + // (cross-attention shape) with no past. Skip Flash for this case; MEA handles it correctly + // via the causal_from_top_left flag, and Unified Unfused uses past_kv_length=0. + // Defined here for visibility — only Flash needs this guard (MEA/Unfused handle upper-left natively). + const bool causal_cross_no_past = parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; + + // Reject causal + TensorScatter decode (S_q < S_kv without past_key). + // Per ONNX spec, is_causal without past_key means upper-left alignment: q[i] attends + // only to kv[0..i]. For decode with external cache (S_q=1, S_kv=cache_size), this means + // q[0] sees only kv[0] — not meaningful for autoregressive generation. + // + // Why is_causal=0 is correct for external cache decode: + // - With S_q=1, there's only one query position at the end of the sequence + // - All KV positions are in the "past" relative to this query — nothing to mask + // - nonpad_kv_seqlen already bounds attention to valid cache positions + // + // For external cache prompt (S_q == S_kv), is_causal=1 works correctly (square matrix, + // upper-left == lower-right). For chunked prefill (S_q > 1 but S_q < S_kv), use an + // explicit attn_mask instead of is_causal. + if (causal_cross_no_past && nonpad_kv_seqlen != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Causal attention with TensorScatter (nonpad_kv_seqlen) and S_q != S_kv without " + "past_key is not supported. 
Per ONNX spec, is_causal without past_key produces " + "upper-left alignment where q[i] only attends to kv[0..i], which for decode (S_q=1) " + "means q[0] sees only kv[0]. Use is_causal=0 for TensorScatter decode; the KV bounds " + "are already enforced by nonpad_kv_seqlen without needing a causal mask. For chunked " + "prefill with external cache, use an explicit attn_mask instead."); + } + #if USE_FLASH_ATTENTION { auto& device_prop = GetDeviceProp(); @@ -1357,16 +1349,16 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.q_num_heads, parameters.kv_num_heads) && parameters.head_size == parameters.v_head_size && !has_output_qk && - // Flash does not support attention masks (no bias parameter in mha_fwd/mha_fwd_kvcache). - // Bool attn_mask + past_key is rejected because Flash uses paged KV cache semantics - // that produce spec-divergent present_kv layout for partial masks (e.g. [T,T,T,F]). - // Unfused handles bool+past_key spec-correctly via standard ConcatPastToPresent. - // TODO(titaiwang): GQA + bool attn_mask + past_key currently has no runner (Flash - // rejected here, unfused doesn't support GQA, MEA blocked by past_key != nullptr). - // Once PR #27851 merges (MEA supports past_key), this gap will be covered. + !causal_cross_no_past && + // Flash does not support attention masks — reject when attn_mask is present. attn_mask == nullptr; if (flash_eligible) { + LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using Flash Attention" + << " (batch=" << parameters.batch_size + << ", q_seq=" << parameters.q_sequence_length + << ", total_seq=" << parameters.total_sequence_length + << ", past=" << (past_key != nullptr ? "yes" : "no") << ")"; return RunFlashAttention(context, Q, K, V, past_key, past_value, nonpad_kv_seqlen, Y, present_key, present_value, parameters); } @@ -1383,7 +1375,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.v_head_size) && !has_output_qk && - past_key == nullptr && + // MEA decode requires head_size == v_head_size for LaunchConcatNewToPastKV + // (single head_size parameter). Fall back to unfused when they differ. + (past_key == nullptr || parameters.head_size == parameters.v_head_size) && // GQA+MEA requires LaunchUngroup which only has fp16/bf16 instantiations. // FP32 GQA must fall through to the unfused path. !(is_gqa && std::is_same::value); @@ -1408,65 +1402,43 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } if (mea_eligible) { + LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using Memory Efficient Attention" + << " (batch=" << parameters.batch_size + << ", q_seq=" << parameters.q_sequence_length + << ", total_seq=" << parameters.total_sequence_length + << ", past=" << (past_key != nullptr ? "yes" : "no") + << ", mask=" << (attn_mask != nullptr ? "yes" : "no") << ")"; return RunMemoryEfficientAttention(context, Q, K, V, attn_mask, past_key, past_value, nonpad_kv_seqlen, Y, present_key, present_value, parameters); } } #endif - // TODO(titaiwang): Support additional output_qk modes beyond kNone and kQK. - // Currently only unfused handles output_qk, and only kNone/kQK modes. 
+ // Fallback: unified unfused attention + // Routes ALL cases to LaunchUnfusedAttention, which handles: + // - GQA natively (reshape-Q trick inside kernel, no K/V head replication) + // - MHA as a degenerate case (group_size=1) + // - fp16/bf16 with large head_size via FP32 QK scratch + // - softcap, attn_mask, causal, past_key+past_value, nonpad_kv_seqlen + // - output_qk (kQK mode: scale * Q @ K^T, before softcap/mask/softmax) + // - past_key with H != H_v (separate concat calls for K and V) + + // Guard: unified kernel only supports kNone and kQK output modes. + // Other modes (kQKMask, kQKSoftCap, kQKSoftMax) expect QK values captured at + // different pipeline stages that the unified kernel does not implement. if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qk_matmul_output_mode other than kNone and kQK is not supported yet " - "in Attention op (CUDA)."); - } - - // GQA-capable unfused fallback (issue #28195). - // Routes through LaunchGqaUnfusedAttention when: - // - GQA (q_num_heads != kv_num_heads) — the MHA unfused runner cannot handle this. - // - fp16/bf16 with head_size > 128 — raw Q*K^T can overflow fp16 storage even - // though cuBLAS accumulates in FP32; the new kernel writes QK to an FP32 scratch. - // The overflow threshold depends on the distribution of Q/K values and scale. - // head_size=256 at scale=1/sqrt(256)=0.0625 is borderline; head_size=512 at - // scale=1.0 (Gemma 4) definitely overflows. We use 128 as a conservative - // threshold since all fused kernels already handle head_size <= 128 anyway. - // This kernel supports softcap. It does not support output_qk, so we only enter it - // when qk_matmul_output_mode_ == kNone. - const bool is_half_or_bf16 = std::is_same::value || std::is_same::value; - const bool needs_fp32_qk_scratch = is_half_or_bf16 && parameters.head_size > 128; - if ((is_gqa || needs_fp32_qk_scratch) && - qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone) { - LOGS_DEFAULT(VERBOSE) << "Attention: using GQA unfused fallback (is_gqa=" << is_gqa - << ", needs_fp32_qk_scratch=" << needs_fp32_qk_scratch - << ", head_size=" << parameters.head_size - << ", softcap=" << parameters.softcap << ")"; - return RunGqaUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, - nonpad_kv_seqlen, Y, present_key, present_value, parameters); - } - - if (is_gqa) { - // qk_matmul_output_mode != kNone reaches here; the unfused MHA runner cannot handle GQA. - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "ONNX Attention with GQA (q_num_heads != kv_num_heads) and output_qk is not " - "supported by the unfused runner."); - } - - // Fallback: unfused MHA attention (legacy runner). - // Softcap is not implemented in the legacy unfused path — it requires Flash or MEA - // (or the new GQA unfused path above, which supports softcap for fp16/bf16/fp32). - // NOTE: keep this guard even if future PRs add softcap to more fused paths — this - // legacy unfused runner does NOT apply softcap and would silently produce wrong results. - if (parameters.softcap > 0.0f) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "softcap requires flash attention or memory efficient attention, " - "but neither is eligible for this configuration. 
Check dtype (fp16/bf16 required for Flash), " - "head_size constraints, and past_key compatibility."); + "Only kNone and kQK output modes are supported in unified unfused attention. Mode: ", + static_cast(qk_matmul_output_mode_)); } + LOGS_DEFAULT(VERBOSE) << "Attention: using unified unfused path (is_gqa=" << is_gqa + << ", head_size=" << parameters.head_size + << ", softcap=" << parameters.softcap << ")"; return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, - nonpad_kv_seqlen, Y, present_key, present_value, output_qk, parameters); + nonpad_kv_seqlen, Y, present_key, present_value, + output_qk, parameters); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h index 2acbf3b2ed829..f11503f154a30 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.h +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -31,27 +31,18 @@ class Attention final : public CudaKernel { Tensor* Y, Tensor* present_key, Tensor* present_value, const attention_helper::AttentionParameters& parameters) const; - Status RunUnfusedAttention( - OpKernelContext* context, - const Tensor* Q, const Tensor* K, const Tensor* V, - const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, - const Tensor* nonpad_kv_seqlen, - Tensor* Y, Tensor* present_key, Tensor* present_value, - Tensor* output_qk, - const attention_helper::AttentionParameters& parameters) const; - - // GQA-capable unfused fallback. Handles: + // Unified unfused fallback. Handles: // - GQA (q_num_heads != kv_num_heads) without K/V head replication. // - fp16/bf16 with large head_size (FP32 QK accumulation, fixes #28195). // - past_key+past_value, attn_mask (bool/float), nonpad_kv_seqlen. - // Does not support: output_qk - // (output_qk modes other than kNone are rejected upstream). - Status RunGqaUnfusedAttention( + // - output_qk (kQK mode: scale * Q @ K^T, before softcap/mask/softmax). + Status RunUnfusedAttention( OpKernelContext* context, const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const; Status ConvertAttnMaskToBias( diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu index 4ab3990b2f85d..2ba7f2e1a9836 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu @@ -89,107 +89,6 @@ Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( return CUDA_CALL(cudaGetLastError()); } -// CUDA kernel to convert nonpad_kv_seqlen to an additive attention bias. 
-// Generates (batch_size, q_seq_len, total_seq_len) output where: -// position t < nonpad_kv_seqlen[b] → 0.0 (attend) -// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out) -template -__global__ void ConvertNonpadKvSeqlenToAttentionBiasKernel( - const int64_t* __restrict__ nonpad_kv_seqlen, - T* __restrict__ attention_bias, - const int batch_size, - const int q_seq_len, - const int total_seq_len, - const float mask_filter_value) { - int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - int64_t total = static_cast(batch_size) * q_seq_len * total_seq_len; - for (; idx < total; idx += static_cast(gridDim.x) * blockDim.x) { - int b = static_cast(idx / (static_cast(q_seq_len) * total_seq_len)); - int t = static_cast(idx % total_seq_len); - int64_t valid_len = nonpad_kv_seqlen[b]; - CUDA_KERNEL_ASSERT(valid_len >= 0 && valid_len <= static_cast(total_seq_len)); - valid_len = max(static_cast(0), min(valid_len, static_cast(total_seq_len))); - attention_bias[idx] = (t < static_cast(valid_len)) ? T(0.0f) : T(mask_filter_value); - } -} - -template -Status LaunchConvertNonpadKvSeqlenToAttentionBias( - const int64_t* nonpad_kv_seqlen, - T* attention_bias, - int batch_size, - int q_seq_len, - int total_seq_len, - float mask_filter_value, - cudaStream_t stream, - int max_threads_per_block) { - int64_t total = static_cast(batch_size) * q_seq_len * total_seq_len; - if (total == 0) { - return Status::OK(); - } - - int threads = static_cast(std::min(static_cast(max_threads_per_block), total)); - int64_t blocks = (total + threads - 1) / threads; - constexpr int64_t kMaxGridDimX = 65535; - unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); - - ConvertNonpadKvSeqlenToAttentionBiasKernel<<>>( - nonpad_kv_seqlen, attention_bias, batch_size, q_seq_len, total_seq_len, mask_filter_value); - - return CUDA_CALL(cudaGetLastError()); -} - -template Status LaunchConvertNonpadKvSeqlenToAttentionBias( - const int64_t*, float*, int, int, int, float, cudaStream_t, int); -template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__half>( - const int64_t*, __half*, int, int, int, float, cudaStream_t, int); -template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__nv_bfloat16>( - const int64_t*, __nv_bfloat16*, int, int, int, float, cudaStream_t, int); - -// Add an addend bias into an existing bias buffer using cyclic broadcasting. -// Used to compose nonpad_kv_seqlen bias [B, q, t] with an attn_mask bias that -// is smaller or equal (e.g. 2D [q, t] cyclic-broadcasts over batch dimension). 
-template -__global__ void AddBiasInPlaceKernel( - T* __restrict__ bias, - const T* __restrict__ addend, - int64_t total_elements, - int64_t addend_elements) { - for (int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - idx < total_elements; - idx += static_cast(gridDim.x) * blockDim.x) { - float sum = static_cast(bias[idx]) + static_cast(addend[idx % addend_elements]); - bias[idx] = T(sum); - } -} - -template -Status LaunchAddBiasInPlace( - T* bias, - const T* addend, - int64_t total_elements, - int64_t addend_elements, - cudaStream_t stream, - int max_threads_per_block) { - if (total_elements == 0 || addend_elements == 0) { - return Status::OK(); - } - - int threads = static_cast(std::min(static_cast(max_threads_per_block), total_elements)); - int64_t blocks = (total_elements + threads - 1) / threads; - constexpr int64_t kMaxGridDimX = 65535; - unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); - - AddBiasInPlaceKernel<<>>( - bias, addend, total_elements, addend_elements); - - return CUDA_CALL(cudaGetLastError()); -} - -template Status LaunchAddBiasInPlace(float*, const float*, int64_t, int64_t, cudaStream_t, int); -template Status LaunchAddBiasInPlace<__half>(__half*, const __half*, int64_t, int64_t, cudaStream_t, int); -template Status LaunchAddBiasInPlace<__nv_bfloat16>(__nv_bfloat16*, const __nv_bfloat16*, int64_t, int64_t, cudaStream_t, int); - // Zero output elements for batches where seqlens_k == 0 (fully masked). // CUTLASS MEA epilogue computes 1/s_prime where s_prime=0 → NaN for fully-masked // batches. The unfused path produces uniform softmax weights (finite mask_filter_value, diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h index 1ada783e9d64d..d2cb4dbbd25ae 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h @@ -31,34 +31,6 @@ Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( cudaStream_t stream, int max_threads_per_block); -// Convert nonpad_kv_seqlen to an additive attention bias for the MHA unfused path. -// Generates a (batch_size, q_seq_len, total_seq_len) tensor where: -// position t < nonpad_kv_seqlen[b] → 0.0 (attend) -// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out) -template -Status LaunchConvertNonpadKvSeqlenToAttentionBias( - const int64_t* nonpad_kv_seqlen, - T* attention_bias, - int batch_size, - int q_seq_len, - int total_seq_len, - float mask_filter_value, - cudaStream_t stream, - int max_threads_per_block); - -// Additively compose an addend bias into an existing bias buffer in-place. -// Supports cyclic broadcasting: addend of size [q, t] is repeated over batch -// to compose with a bias of size [B, q, t]. When both have the same number -// of elements (e.g. 4D mask [B, 1, q, t]), it performs a direct element-wise add. -template -Status LaunchAddBiasInPlace( - T* bias, - const T* addend, - int64_t total_elements, - int64_t addend_elements, - cudaStream_t stream, - int max_threads_per_block); - // Zero output elements for batches where seqlens_k == 0 (fully masked). // Used in the MEA path only: CUTLASS epilogue computes 1/s_prime where s_prime=0, // producing NaN for fully-masked batches. 
This kernel overwrites those NaN outputs diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index 0cf95141b7a6c..40c45db2dfd66 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -8,6 +8,8 @@ #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/scoped_env_vars.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace test { @@ -91,8 +93,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, y, false, 0, 3e-5f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, present_key); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, present_value); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, qk_matmul_output); } else if (tensor_type == TensorType::kFloat16) { @@ -120,8 +126,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, ToFloat16(y), false, 0, 3e-3f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, ToFloat16(present_key)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, ToFloat16(present_value)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, ToFloat16(qk_matmul_output)); } else { @@ -149,8 +159,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, FloatsToBFloat16s(y), false, 0, 3e-3f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, FloatsToBFloat16s(present_key)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, FloatsToBFloat16s(present_value)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, FloatsToBFloat16s(qk_matmul_output)); } @@ -516,11 +530,10 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) { // Regression guard: all-false bool mask in decode mode (past_sequence_length > 0). // Guards against a bug where fully-masked batches produce NaN or incorrect output. -// Expected behavior: uniform softmax over past KV values produces Y = mean-of-V. -// With past_v = [10,20,30,40] and [20,40,60,80] per head, and all positions masked out, -// softmax(all -inf + constant mask_filter_value) → uniform weights → Y = {25, 50}. -// This test originally came from upstream/main and validates that both CPU and CUDA -// (unfused path) handle the all-false mask case identically. +// Expected behavior: uniform softmax over all KV values produces Y = mean-of-V. +// On CUDA, MEA decode handles this config (total_seq=4, 4-aligned). The capped +// mask_filter_value (-1e+30) in ConvertAttnMaskToBias prevents CUTLASS overflow, +// producing correct uniform softmax → mean(V). 
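+// Why mean-of-V: with every logit equal to the same finite mask_filter_value,
+// softmax is exactly uniform (each weight = 1/total_seq), so Y reduces to the
+// arithmetic mean of the V rows.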
TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { int batch_size = 1; int q_num_heads = 2; @@ -609,8 +622,9 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { ); } -// Unfused decode path with fp16 and all-true bool attention mask. -// Flash rejects attn_mask (requires attn_mask==nullptr), so CUDA routes to unfused. +// Decode path with fp16 and all-true bool attention mask. +// Flash rejects attn_mask (requires attn_mask==nullptr). MEA handles decode with +// bool mask via additive bias (past_key concat + ConvertAttnMaskToBias). // head_size=64. Uniform keys make output analytically verifiable: // all attention scores are equal, so softmax is uniform over all positions. TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) { @@ -695,8 +709,8 @@ TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) { // Decode with partial bool mask [T,T,T,F]: the new token is masked out. // With mask [T,T,T,F] past_seq=3 total=4: only positions 0,1,2 are attended (past only). -// Flash is ineligible (bool+past_key rejected), so CUDA uses unfused which handles this -// spec-correctly via standard ConcatPastToPresent + element-wise mask application. +// Flash is ineligible (bool+past_key rejected). MEA handles decode with bool mask +// via additive bias (past_key concat + ConvertAttnMaskToBias). // Y = uniform mean over the 3 attended past values (Q=K=constant → uniform softmax). // CPU always runs; CUDA runs when SM 5.3+ is available. TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { @@ -781,7 +795,8 @@ TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { // Multi-batch decode with per-batch partial bool masks. // batch_size=2: batch 0 [T,T,T,F,F,F] (3 leading trues), batch 1 [T,T,T,T,T,T] (all true). -// Flash is ineligible (bool+past_key rejected), CUDA uses unfused. +// Flash is ineligible (bool+past_key rejected). MEA rejected by CUTLASS bias alignment +// (total_seq=6, 6%4≠0), so CUDA falls through to unfused. // Unfused applies standard ConcatPastToPresent (new token at position past_sequence_length=5 // for all batches) and element-wise mask in softmax. // Runs on both CPU and CUDA to verify cross-EP consistency. @@ -988,9 +1003,8 @@ TEST(AttentionTest, Attention4DSoftCap) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - // disable_cuda: head_size(8) != v_head_size(10) blocks Flash, past_key blocks MEA, - // unfused path doesn't support softcap. Needs test with head_size == v_head_size and no past. - false, true, true // disable_cpu, disable_cuda, disable_dml + // head_size(8) != v_head_size(10) blocks Flash and MEA decode; falls to unfused which now supports softcap. + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1018,9 +1032,8 @@ TEST(AttentionTest, Attention4DSoftCapFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - // disable_cuda: head_size(8) != v_head_size(10) blocks Flash, past_key blocks MEA, - // unfused path doesn't support softcap. Needs test with head_size == v_head_size and no past. 
- false, true, true // disable_cpu, disable_cuda, disable_dml + // head_size(8) != v_head_size(10) blocks Flash and MEA decode; falls to unfused which now supports softcap. + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1160,7 +1173,6 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { ); } -// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausal) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -1250,7 +1262,6 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { ); } -// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { int batch_size = 2; // Q.shape[0] int q_num_heads = 1; // Q.shape[1] @@ -2308,10 +2319,10 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_WithFloatAttnMask_MultiBatch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused attention with FP32 QK accumulation for large head_size (> 128). -// This exercises the RunGqaUnfusedAttention path in attention.cc which uses +// Unfused attention with FP32 QK accumulation for large head_size (> 128). +// This exercises the RunUnfusedAttention path in attention.cc which uses // an FP32 scratch buffer for QK matmul to prevent overflow in fp16. -TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { +TEST(AttentionTest, Attention_Unfused_LargeHeadSize_FP16) { if (!HasCudaEnvironment(530)) { return; // fp16 requires SM 5.3+ } @@ -2371,9 +2382,9 @@ TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused attention with causal mask and large head_size. -// Verifies that is_causal works correctly in the unfused GQA path. -TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { +// Unfused attention with causal mask and large head_size. +// Verifies that is_causal works correctly in the unfused path. +TEST(AttentionTest, Attention_Unfused_LargeHeadSize_Causal_FP16) { if (!HasCudaEnvironment(530)) { return; // fp16 requires SM 5.3+ } @@ -2440,8 +2451,8 @@ TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with past_key + attn_mask: exercises concat + bias path together. -TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { +// Unfused with past_key + attn_mask: exercises concat + bias path together. +TEST(AttentionTest, Attention_Unfused_PastKey_AttnMask_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2519,8 +2530,8 @@ TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with softcap + attn_mask: verifies the softcap + bias interaction. -TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) { +// Unfused with softcap + attn_mask: verifies the softcap + bias interaction. +TEST(AttentionTest, Attention_Unfused_Softcap_AttnMask_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2572,8 +2583,8 @@ TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with BSNH (3D) input: previous tests all use 4D BNSH input. 
-TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) { +// Unfused with BSNH (3D) input: previous tests all use 4D BNSH input. +TEST(AttentionTest, Attention_Unfused_BSNH_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2622,8 +2633,8 @@ TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with fp32: exercises the float template instantiation. -TEST(AttentionTest, Attention_GqaUnfused_FP32) { +// Unfused with fp32: exercises the float template instantiation. +TEST(AttentionTest, Attention_Unfused_FP32) { if (!HasCudaEnvironment(0)) { return; } @@ -2673,5 +2684,296 @@ TEST(AttentionTest, Attention_GqaUnfused_FP32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// Test MEA decode path by disabling Flash Attention. +// Uses the same Attention4DDefaultBasic data (head_size == v_head_size, fp16 with past_key) +// but forces MEA runner via environment variable. +TEST(AttentionTest, Attention4DMEADecodeFloat16) { + int batch_size = 2; + int q_num_heads = 3; + int q_sequence_length = 4; + int head_size = 8; + int kv_sequence_length = 6; + int kv_num_heads = 3; + int v_head_size = 8; + int past_sequence_length = 5; + + // Simple test data: one-hot Q/K/V to make expected output predictable + size_t q_size = batch_size * q_num_heads * q_sequence_length * head_size; + size_t k_size = batch_size * kv_num_heads * kv_sequence_length * head_size; + size_t v_size = batch_size * kv_num_heads * kv_sequence_length * v_head_size; + + std::vector q(q_size, 0.0f); + q[0] = 1.0f; // first element of first query is 1 + std::vector k(k_size, 0.0f); + k[0] = 1.0f; // first element of first key is 1 + std::vector v(v_size, 0.0f); + v[0] = 1.0f; // first element of first value is 1 + + // Expected output matches Attention4DDefaultBasic (same data, same math regardless of runner) + std::vector y = {0.221683f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + + // Force MEA by disabling Flash Attention + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDisableFlashAttention, "1"}}}; + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, + y, std::vector(), std::vector(), std::vector(), + true, false, true // disable_cpu, disable_cuda=false (test CUDA MEA), disable_dml + ); +} + +// Regression test for output_qk + softcap: verifies that qk_matmul_output_mode=0 (kQK) +// returns RAW Q*K logits (before softcap), not softcapped values. +// This test would FAIL if CopyQK were moved after ApplySoftcap: +// - Correct (CopyQK before softcap): output_qk = 2.0 (raw dot product) +// - Wrong (CopyQK after softcap): output_qk = tanh(2.0) ≈ 0.964 (clamped by softcap=1.0) +// Uses constant Q=1, K=1 with head_size=4 so QK = scale * dot(Q,K) = 0.5 * 4 = 2.0. +// v_head_size(6) != head_size(4) blocks Flash Attention and MEA decode, forcing unfused path. +TEST(AttentionTest, Attention4DSoftCapOutputQkRawLogits) { + int batch_size = 1; + int q_num_heads = 2; + int q_sequence_length = 2; + int head_size = 4; + int kv_sequence_length = 3; + int kv_num_heads = 2; + int v_head_size = 6; + int past_sequence_length = 0; + int total_sequence_length = past_sequence_length + kv_sequence_length; + + // Constant Q and K: all 1.0 + // QK = scale * dot(Q[i], K[j]) = (1/sqrt(4)) * 4 = 2.0 for all (i,j) pairs + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 1.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 1.0f); + + // V: position j gets value (j+1)*0.1 across all v_head_size dims + std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size); + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < kv_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < v_head_size; h++) { + v[(n * kv_sequence_length + s) * v_head_size + h] = val; + } + } + } + + // Expected output_qk: raw QK logits = 2.0 for all entries + // Shape: [batch, q_num_heads, q_seq, total_seq] = [1, 2, 2, 3] = 12 values + std::vector expected_qk(batch_size * q_num_heads * q_sequence_length * total_sequence_length, 2.0f); + + // Expected Y: softcap(2.0) ≈ 0.964 for all QK → uniform softmax → Y = mean(V) = 0.2 + // Shape: [batch, q_num_heads, q_seq, v_head_size] = [1, 2, 2, 6] = 24 values + std::vector ys(batch_size * q_num_heads * q_sequence_length * v_head_size, 0.2f); + + // present_key = K (no past), present_value = V (no past) + // These must be provided so the OpTester has all 4 outputs for correct index mapping. 
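+  // Arithmetic check: the softcapped logits are all equal, so softmax is
+  // uniform over the 3 KV positions and Y = (0.1 + 0.2 + 0.3) / 3 = 0.2 in
+  // every v_head_size lane.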
+ RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, 0, std::numeric_limits::quiet_NaN(), 1.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode=kQK, scale=default, softcap=1.0 + ys, k, v, expected_qk, + false, false, true // disable_cpu, disable_cuda, disable_dml — runs on both CPU and CUDA unfused (v_head_size != head_size blocks Flash/MEA) + ); +} + +// ============================================================================ +// Causal alignment tests: verify upper-left (no past) vs lower-right (with past) +// These are CUDA-only tests that validate the causal masking fix. +// ============================================================================ + +// Test: Causal + cross-attention (S_q=3, S_kv=5, no past) +// ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. +// V is identity-like so output directly reveals which KV positions were attended. +// Exercises MEA (fp32, head_size divisible by 4) or Unfused kernel on CUDA. +TEST(AttentionTest, Attention4DCausalCrossAttentionUpperLeft) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 3; + int head_size = 4; + int kv_sequence_length = 5; + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, + 0.4f, 0.8f, 0.1f, 0.6f, + 0.7f, 0.3f, 0.9f, 0.5f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f, + 0.3f, 0.2f, 0.1f, 0.4f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 0.5f, 0.5f, 0.5f, 0.5f}; + // Upper-left causal (scale=0.5): q0→v[0]=[1,0,0,0], q1→softmax([0.47,0.375])@v[0:2], q2→softmax([0.6,0.48,0.495])@v[0:3] + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.523732f, 0.476268f, 0.000000f, 0.000000f, + 0.358777f, 0.318207f, 0.323016f, 0.000000f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + +// Test: Causal + cross-attention (S_q=3, S_kv=5, no past) with head_size=8. +// ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. +// head_size=8 targets the MEA path (below Flash minimum of 32) but validates +// correctness regardless of which kernel handles it. head_size=8 satisfies +// MEA's head_size%8==0 requirement, so this exercises MEA's CausalFromTopLeft +// path (via causal_from_top_left=true when past_seq==0). 
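+// The flag is wired in RunMemoryEfficientAttention as
+// p.causal_from_top_left = (parameters.past_sequence_length == 0).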
+// V is identity-like so output directly reveals which KV positions were attended. +TEST(AttentionTest, Attention4DCausalCrossAttentionUpperLeftSmallHead) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 3; + int head_size = 8; + int kv_sequence_length = 5; + int kv_num_heads = 1; + int v_head_size = 8; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, 0.8f, 0.4f, 0.6f, 0.1f, + 0.4f, 0.8f, 0.1f, 0.6f, 0.3f, 0.7f, 0.2f, 0.9f, + 0.7f, 0.3f, 0.9f, 0.5f, 0.1f, 0.6f, 0.4f, 0.8f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, 0.1f, 0.3f, 0.5f, 0.7f, + 0.1f, 0.3f, 0.5f, 0.7f, 0.9f, 0.2f, 0.4f, 0.6f, + 0.9f, 0.1f, 0.2f, 0.3f, 0.4f, 0.8f, 0.7f, 0.5f, + 0.5f, 0.6f, 0.7f, 0.8f, 0.2f, 0.4f, 0.3f, 0.1f, + 0.3f, 0.2f, 0.1f, 0.4f, 0.6f, 0.5f, 0.8f, 0.9f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}; + // Upper-left causal (scale=1/sqrt(8)): q0→v[0], q1→softmax(scaled_scores[0:2])@v[0:2], q2→softmax(scaled_scores[0:3])@v[0:3] + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.511488f, 0.488512f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.344711f, 0.305668f, 0.349621f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} +// Lower-right alignment: q0 at absolute position 4 attends to all 5 KV positions. +// Exercises Unfused or MEA decode path on CUDA. +TEST(AttentionTest, Attention4DCausalDecodeWithPastLowerRight) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 1; + int head_size = 4; + int kv_sequence_length = 1; // new KV tokens + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 4; // total = 4 + 1 = 5 + + // clang-format off + std::vector q = {0.7f, 0.3f, 0.9f, 0.5f}; + std::vector k = {0.3f, 0.2f, 0.1f, 0.4f}; // new key + std::vector v = {0.5f, 0.5f, 0.5f, 0.5f}; // new value + std::vector past_key = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f}; + std::vector past_value = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + // Lower-right: q0 at pos 4 sees all 5 positions. 
+  // Scaled scores (scale=0.5 already applied): [0.6, 0.48, 0.495, 0.78, 0.28].
+  std::vector<float> y = {0.289363f, 0.265357f, 0.268203f, 0.331229f};
+  // present = concat(past, new) in BNSH layout
+  std::vector<float> present_key = {0.2f, 0.4f, 0.6f, 0.8f,
+                                    0.1f, 0.3f, 0.5f, 0.7f,
+                                    0.9f, 0.1f, 0.2f, 0.3f,
+                                    0.5f, 0.6f, 0.7f, 0.8f,
+                                    0.3f, 0.2f, 0.1f, 0.4f};
+  std::vector<float> present_value = {1.0f, 0.0f, 0.0f, 0.0f,
+                                      0.0f, 1.0f, 0.0f, 0.0f,
+                                      0.0f, 0.0f, 1.0f, 0.0f,
+                                      0.0f, 0.0f, 0.0f, 1.0f,
+                                      0.5f, 0.5f, 0.5f, 0.5f};
+  // clang-format on
+
+  ASSERT_EQ(q.size(), static_cast<size_t>(batch_size * q_num_heads * q_sequence_length * head_size));
+  ASSERT_EQ(k.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * head_size));
+  ASSERT_EQ(v.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * v_head_size));
+  ASSERT_EQ(y.size(), static_cast<size_t>(batch_size * q_num_heads * q_sequence_length * v_head_size));
+
+  RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
+            q, k, v, std::vector<float>(), std::initializer_list<float>(), past_key, past_value,
+            1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,  // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
+            y, present_key, present_value, std::vector<float>(),
+            false, false, true  // disable_cpu, disable_cuda, disable_dml
+  );
+}
+
+// Test: Causal + square (S_q=S_kv=4, no past)
+// Upper-left == lower-right for square matrices. Verifies correctness on both paths.
+// Exercises MEA or Unfused kernel depending on GPU capability.
+TEST(AttentionTest, Attention4DCausalSquareNoPast) {
+  int batch_size = 1;
+  int q_num_heads = 1;
+  int q_sequence_length = 4;
+  int head_size = 4;
+  int kv_sequence_length = 4;
+  int kv_num_heads = 1;
+  int v_head_size = 4;
+  int past_sequence_length = 0;
+
+  // clang-format off
+  std::vector<float> q = {1.0f, 0.5f, 0.3f, 0.2f,
+                          0.4f, 0.8f, 0.1f, 0.6f,
+                          0.7f, 0.3f, 0.9f, 0.5f,
+                          0.2f, 0.6f, 0.4f, 0.8f};
+  std::vector<float> k = {0.2f, 0.4f, 0.6f, 0.8f,
+                          0.1f, 0.3f, 0.5f, 0.7f,
+                          0.9f, 0.1f, 0.2f, 0.3f,
+                          0.5f, 0.6f, 0.7f, 0.8f};
+  std::vector<float> v = {1.0f, 0.0f, 0.0f, 0.0f,
+                          0.0f, 1.0f, 0.0f, 0.0f,
+                          0.0f, 0.0f, 1.0f, 0.0f,
+                          0.0f, 0.0f, 0.0f, 1.0f};
+  // Both alignments give identical result for square (no past).
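+  // Worked derivation for row 1 (illustrative; scale = 1/sqrt(head_size) = 0.5):
+  //   q1·k0 = 0.94 and q1·k1 = 0.75, so the scaled causal scores are [0.47, 0.375];
+  //   softmax gives exp(0.47) / (exp(0.47) + exp(0.375)) ≈ 0.523732 and 0.476268,
+  //   hence y1 = 0.523732 * v[0] + 0.476268 * v[1].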
+  std::vector<float> y = {1.000000f, 0.000000f, 0.000000f, 0.000000f,
+                          0.523732f, 0.476268f, 0.000000f, 0.000000f,
+                          0.358777f, 0.318207f, 0.323016f, 0.000000f,
+                          0.265821f, 0.240525f, 0.196925f, 0.296730f};
+  // clang-format on
+
+  ASSERT_EQ(q.size(), static_cast<size_t>(batch_size * q_num_heads * q_sequence_length * head_size));
+  ASSERT_EQ(k.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * head_size));
+  ASSERT_EQ(v.size(), static_cast<size_t>(batch_size * kv_num_heads * kv_sequence_length * v_head_size));
+  ASSERT_EQ(y.size(), static_cast<size_t>(batch_size * q_num_heads * q_sequence_length * v_head_size));
+
+  RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length,
+            q, k, v, std::vector<float>(), std::initializer_list<float>(), std::vector<float>(), std::vector<float>(),
+            1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,  // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
+            y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
+            false, false, true  // disable_cpu, disable_cuda, disable_dml
+  );
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py
index 48640fa38aca2..1ab38fb1ea0f9 100644
--- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py
+++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py
@@ -11,7 +11,7 @@
 # -------------------------------------------------------------------------
 
 """
-Shared utilities for ONNX Attention op (opset 23) tests.
+Shared utilities for ONNX Attention op (opset 23/24) tests.
 
 Contains configuration, ONNX graph builders, reference implementation,
 and parity check helpers used by both GQA and MHA test modules.
@@ -38,9 +38,6 @@
 # Reduces number of tests to run for faster pipeline checks
 pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1"
 
-# Number of values per parameter (compared to pipeline mode)
-param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2
-
 # When quick build is used, flash attention only supports head_size=128
 quick_build = ", quick-build=" in get_build_info()
 
@@ -71,14 +68,6 @@
     torch.int8: TensorProto.INT8,
 }
 
-TORCH_DTYPE_MAP = {
-    "float32": torch.float32,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "int8": torch.int8,
-    "int4": torch.uint8,
-}
-
 
 @dataclass
 class AttentionConfig:
@@ -88,6 +77,7 @@ class AttentionConfig:
     q_num_heads: int
     kv_num_heads: int
     head_size: int
+    v_head_size: int = 0  # 0 means same as head_size; set explicitly for asymmetric Q/V head sizes
     is_causal: int = 0
     past_kv_sequence_length: int = 0
     softcap: float = 0.0
@@ -115,7 +105,7 @@ def create_attention_node_and_io(
     """
     Create ONNX Attention op node and I/O definitions for testing.
- ONNX Attention op (opset 23) inputs: + ONNX Attention op (opset 23/24) inputs: - 0: Q (query) - required - 1: K (key) - required - 2: V (value) - required @@ -135,6 +125,9 @@ def create_attention_node_and_io( else: # Prompt (no past KV cache) present_kv_seqlen = config.kv_sequence_length + # Effective v_head_size: defaults to head_size when not explicitly set + effective_v_head_size = config.v_head_size or config.head_size + if not config.kv_cache_type: config.kv_cache_type = { TensorProto.FLOAT16: "float16", @@ -168,7 +161,7 @@ def create_attention_node_and_io( while inputs and inputs[-1] == "": inputs.pop() - # ONNX Attention op attributes (opset 23) + # ONNX Attention op attributes (opset 23/24) node = helper.make_node( op_type="Attention", inputs=inputs, @@ -199,13 +192,14 @@ def create_attention_node_and_io( helper.make_tensor_value_info( "value", ort_type, - [config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size], + [config.batch_size, config.kv_num_heads, config.kv_sequence_length, effective_v_head_size], ), ] else: # 3D inputs: [batch, seq_len, hidden_size] q_hidden_size = config.q_num_heads * config.head_size kv_hidden_size = config.kv_num_heads * config.head_size + v_hidden_size = config.kv_num_heads * effective_v_head_size graph_input = [ helper.make_tensor_value_info( "query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size] @@ -214,7 +208,7 @@ def create_attention_node_and_io( "key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] ), helper.make_tensor_value_info( - "value", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + "value", ort_type, [config.batch_size, config.kv_sequence_length, v_hidden_size] ), ] @@ -263,10 +257,11 @@ def create_attention_node_and_io( # Shape: [batch, num_heads, past_seq_len, head_size] (4D BNSH format) if is_past: past_k_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, config.head_size] + past_v_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, effective_v_head_size] graph_input.extend( [ helper.make_tensor_value_info("past_key", cache_ort_type, past_k_shape), - helper.make_tensor_value_info("past_value", cache_ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", cache_ort_type, past_v_shape), ] ) @@ -276,16 +271,17 @@ def create_attention_node_and_io( # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + output_v_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, effective_v_head_size] if config.use_4d_bnsh: - output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size] + output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size] else: - output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] + output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * effective_v_head_size] graph_output = [ helper.make_tensor_value_info("output", ort_type, output_shape), helper.make_tensor_value_info("present_key", cache_ort_type, output_k_shape), - helper.make_tensor_value_info("present_value", cache_ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", cache_ort_type, output_v_shape), ] if output_qk > 0: @@ -447,24 +443,26 @@ def attention_prompt_func( bind_tensor(io_binding, "nonpad_kv_seqlen", nonpad_kv_seqlen, device, 
TensorProto.INT64) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape for prompt (no past) present_seqlen = config.kv_sequence_length - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] # Determine dtype for cache tensors cache_dtype = out_dtype @@ -473,8 +471,8 @@ def attention_prompt_func( else: cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -565,28 +563,30 @@ def attention_past_func( bind_tensor(io_binding, "past_value", past_v_sliced, device, cache_ort_type) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape (past + new) present_seqlen = total_seq_len - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] cache_dtype = out_dtype - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) 
bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -645,6 +645,9 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1]) + # Corrected ordering per onnx/onnx#7865: QK → softcap → add bias/mask → softmax + # Softcap must be applied before mask so that -inf mask values are not + # squashed to finite -softcap, which would leak probability to masked positions. if softcap > 0: scores = (scores / softcap).tanh() * softcap diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index c4e3c1b19e85e..55f07666e8c6f 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -98,16 +98,19 @@ def parity_check_gqa_prompt( ) v = torch.randn_like(k) * std - # --- Create attn_mask as boolean padding mask (simulating seqlens_k) --- + # --- Create attn_mask matching the ONNX model's expected shape --- attn_mask = None key_padding_mask = None if config.has_attn_mask: + total_seq = config.past_kv_sequence_length + config.kv_sequence_length + # 2D mask shape: [q_seq, total_seq] per ONNX spec (matches create_attention_graph_prompt) attn_mask = torch.ones( - config.batch_size, - config.kv_sequence_length, + config.q_sequence_length, + total_seq, device=device, dtype=torch.bool, ) + # key_padding_mask for PyTorch reference: [batch, kv_seq] key_padding_mask = torch.ones( config.batch_size, config.kv_sequence_length, @@ -115,6 +118,17 @@ def parity_check_gqa_prompt( dtype=torch.bool, ) + # --- Create nonpad_kv_seqlen tensor if needed (opset 24+) --- + nonpad_kv_seqlen = None + if config.has_nonpad_kv_seqlen: + # Each batch element has the full kv_sequence_length as valid (no padding) + nonpad_kv_seqlen = torch.full( + (config.batch_size,), + config.kv_sequence_length, + device=device, + dtype=torch.int64, + ) + # --- PyTorch Reference Path --- out_ref, _ = attention_ref( q=q, @@ -138,6 +152,7 @@ def parity_check_gqa_prompt( ep=ep, device=device, ort_type=ort_type, + nonpad_kv_seqlen=nonpad_kv_seqlen, ) if i == 0: first_out = out.clone() @@ -271,7 +286,7 @@ def parity_check_gqa_past( key_padding_mask = None if config.has_attn_mask: attn_mask = torch.ones( - config.batch_size, + config.q_sequence_length, total_seq_len, device=device, dtype=torch.bool, @@ -441,7 +456,7 @@ def parity_check_gqa_prompt_with_padding( ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -568,7 +583,7 @@ def parity_check_gqa_past_with_padding( ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_past_func( + out, _present_k, _present_v = attention_past_func( q=q, past_k=past_k, past_v=past_v, @@ -708,6 +723,9 @@ def gqa_prompt_padding_test_cases(): # Guard case: batch_size=4 != q_seq_len=1 (decode). This catches the original bug # where 2D mask was [batch, total_seq] instead of [q_seq, total_seq]. + # NOTE: is_causal=0 because per ONNX spec, is_causal with S_q!=S_kv and no past_key + # gives upper-left alignment (q[0] sees only kv[0]), which is not meaningful for decode. + # KV bounds are enforced by the attention mask instead. 
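+    # Concretely: with S_q=1 against a longer total KV sequence, upper-left causal
+    # restricts the single query row to kv[0] alone, masking every other position.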
for mask_dims in mask_dims_options: config = AttentionConfig( batch_size=4, @@ -717,7 +735,7 @@ def gqa_prompt_padding_test_cases(): q_num_heads=8, kv_num_heads=2, head_size=128, - is_causal=1, + is_causal=0, has_attn_mask=True, attn_mask_dims=mask_dims, ) @@ -730,7 +748,9 @@ def gqa_past_padding_test_cases(): Generate test cases for ONNX Attention op GQA path with boolean padding masks in decoding phase. """ batches = [2] - seqs = [(1, 32)] + # past=31 + new=1 = total_seq=32, which satisfies MEA's bias alignment + # requirement (total_seq % 4 == 0) when attn_mask is present. + seqs = [(1, 31)] heads = [(8, 2)] h_sizes = [128] mask_dims_options = [2, 3, 4] @@ -863,22 +883,37 @@ def test_gqa_prompt_memory_efficient(self, name, config): # flash attention. -# TODO(titaiwang): Re-enable once PR #27851 merges (MEA supports past_key for GQA). -# Flash now rejects attn_mask (requires attn_mask==nullptr). GQA + bool mask + past_key -# has no runner until MEA supports past_key. See issue #27885. -@unittest.skip( - "Flash now rejects attn_mask. GQA + bool mask + past_key has no runner " - "until PR #27851 (MEA with past_key). See issue #27885." -) -@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") -@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) -class TestONNXAttentionPaddingMaskGQA(unittest.TestCase): +@unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQABF16(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention using BFloat16.""" + + @parameterized.expand(gqa_past_test_cases()) + def test_gqa_past_memory_efficient_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config.kv_cache_type = "bfloat16" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionPaddingMaskMEAGQA(unittest.TestCase): """ Test ONNX Attention op (opset 23) GQA path with boolean padding masks. - SKIPPED: Flash now requires attn_mask == nullptr. GQA + bool attn_mask + - past_key currently has no runner (Flash rejected, unfused doesn't support GQA, - MEA blocked by past_key != nullptr). Will be re-enabled when PR #27851 lands. + GQA + bool attn_mask + past_key uses the MEA decode path (Flash requires + attn_mask == nullptr). MEA handles bool masks via additive bias conversion. 
These tests verify that the boolean attn_mask is correctly converted to sequence lengths on GPU and that the attention computation respects the @@ -1011,7 +1046,7 @@ def parity_check_gqa_prompt_with_nonpad_kv_seqlen( # ORT path: use nonpad_kv_seqlen (int64 tensor) nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -1344,10 +1379,10 @@ def test_gqa_prompt_float_mask_4d(self): # ################################################################################################# -# Large Head Size Unfused GQA Tests (head_size=512, fixes #28195) +# Large Head Size Unfused Tests (head_size=512, fixes #28195) # # Flash Attention and Memory-Efficient Attention cap at head_size=256. For head_size=512 the -# op falls through to RunGqaUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, +# op falls through to RunUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, # eliminating fp16/bf16 overflow that caused NaNs (e.g. Gemma 4 global-attention layers). # # These tests deliberately disable both Flash and MEA to make the unfused fallback explicit @@ -1425,7 +1460,7 @@ class TestONNXAttentionGQALargeHeadUnfused(unittest.TestCase): Regression tests for GQA with head_size=512 via the unfused FP32-QK path (issue #28195). Flash Attention and MEA both cap at head_size=256. With both disabled the op routes - to RunGqaUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid + to RunUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid fp16/bf16 overflow that produced NaNs for Gemma 4 global-attention layers. Validates: no NaNs, numerical parity vs. PyTorch SDPA reference, for fp16 and bf16. @@ -1532,5 +1567,355 @@ def test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16(self): self.assertLess(out.float().max().item(), 1.0) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQAFloatMaskDecode(unittest.TestCase): + """ + Test GQA with float additive attention mask during decode using MEA. + + This exercises the MEA decode path with float additive masks — a scenario + that was a HARD ERROR before MEA+decode support (MEA was ineligible + when past_key was present, so this fell through to no kernel). 
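+
+        The 31 + 1 = 32 setup below also satisfies the CUTLASS bias row-stride rule
+        (total_sequence_length % 4 == 0) that MEA requires when a mask is present.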
+ """ + + def test_gqa_past_float_mask_4d(self): + """Test GQA decode with 4D float additive mask via MEA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment for MEA) + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 8, 128, device=device, dtype=torch_type) * std + + past_k = torch.randn(2, 2, 31, 128, device=device, dtype=torch_type) * std + past_v = torch.randn_like(past_k) * std + + new_k = torch.randn(2, 1, 2, 128, device=device, dtype=torch_type) * std + new_v = torch.randn_like(new_k) * std + + total_seq_len = 32 # past(31) + new(1), satisfies MEA bias alignment (32 % 4 == 0) + + # Create additive mask with padding pattern: batch 0 has 28 valid past, batch 1 full + past_seqlens = torch.tensor([28, 31], dtype=torch.int32, device=device) + total_seqlens = past_seqlens + config.kv_sequence_length + + attn_mask = create_additive_mask_from_seqlens( + seqlens=total_seqlens, + total_seq_len=total_seq_len, + mask_dims=4, + q_seq_len=1, + num_heads=8, + device=device, + dtype=torch_type, + ) + + # Zero padded past positions for batch 0 + past_k[0, :, 28:, :] = 0 + past_v[0, :, 28:, :] = 0 + + # Reference: concat past + new, then compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + # Expand 4D mask to reference attn_bias [batch, heads, q_seq, total_seq] + attn_bias_ref = attn_mask + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_ort = out_ort.reshape(2, 1, 8, 128) + + # --- Verify present_k/v match concatenated reference --- + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # --- Verify output --- + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMEAGQASoftcap(unittest.TestCase): + """ + Test softcap 
support for GQA via the Memory Efficient Attention path. + + Disables Flash Attention to force MEA. Verifies softcap with and without + attention mask for GQA (kv_num_heads != q_num_heads). + + MEA alignment requirement: total_seq % 4 == 0 when attn_mask is present. + """ + + def test_mea_gqa_softcap_with_mask_prompt_fp16(self): + """MEA GQA softcap + causal mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, # total_seq=8, divisible by 4 + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_no_mask_prompt_fp16(self): + """MEA GQA softcap without explicit mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_with_mask_decode_fp16(self): + """MEA GQA softcap + causal mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32, divisible by 4 + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + ) + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_mask_ordering_no_leakage_prompt_fp16(self): + """Guard test: verify MEA GQA softcap + mask ordering prevents attention leakage. + + Same poison-value technique as the MHA ordering test, but with GQA + (kv_num_heads != q_num_heads) forced to MEA path. 
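+
+        Concretely, the wrong ordering computes softcap * tanh(-inf / softcap) = -softcap,
+        a finite logit, so masked positions would retain nonzero softmax weight.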
+ """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 # divisible by 4 for MEA alignment + q_num_heads = 4 + kv_num_heads = 2 + head_size = 64 + softcap_val = 2.0 + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(batch_size, q_seq, q_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, kv_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, kv_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + # 4D mask: [batch, q_num_heads, q_seq, kv_seq] + attn_mask = torch.zeros(batch_size, q_num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"MEA GQA attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means MEA applies softcap AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, q_num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.02, atol=0.02) + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping Flash GQA softcap tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) +class TestONNXAttentionFlashGQASoftcap(unittest.TestCase): + """Test softcap support for GQA via the Flash Attention path. + + Flash does NOT accept explicit attn_mask for GQA — uses nonpad_kv_seqlen + (padding mask) instead. Tests verify softcap works correctly through Flash + with and without padding mask. + + Requires SM80+ (Flash Attention hardware requirement). 
+ """ + + def test_flash_gqa_softcap_with_padding_mask_prompt_fp16(self): + """Flash GQA softcap + padding mask (nonpad_kv_seqlen), prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_nonpad_kv_seqlen=True, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_gqa_softcap_no_mask_prompt_fp16(self): + """Flash GQA softcap without any mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_gqa_softcap_no_mask_decode_fp16(self): + """Flash GQA softcap, decode phase (past KV), fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index abe180ee35787..a488e11e39d20 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -99,9 +99,14 @@ def parity_check_mha_prompt( attn_mask = None attn_bias_ref = None if config.has_attn_mask: - # Create additive mask (0 for valid, -inf for masked) - # For prompt without padding, create a causal-style or zero mask - seqlens = torch.full((config.batch_size,), config.kv_sequence_length, dtype=torch.int32, device=device) + # When softcap is present, use partial seqlens so the mask has both valid and masked + # positions — otherwise the all-zero mask can't detect softcap→bias ordering bugs. + # For non-softcap tests, use full seqlens (existing behavior). + if config.softcap > 0: + mask_valid_len = max(1, config.kv_sequence_length * 3 // 4) + else: + mask_valid_len = config.kv_sequence_length + seqlens = torch.full((config.batch_size,), mask_valid_len, dtype=torch.int32, device=device) attn_mask = create_additive_mask_from_seqlens( seqlens=seqlens, total_seq_len=config.kv_sequence_length, @@ -127,6 +132,7 @@ def parity_check_mha_prompt( v=v, attn_bias=attn_bias_ref, causal=causal, + softcap=config.softcap, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -146,9 +152,15 @@ def parity_check_mha_prompt( if i == 0: first_out = out.clone() else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + # FP16/BF16 GPU kernels may produce bit-level non-determinism across runs. 
+ det_atol = 0 if torch_type == torch.float32 else 1e-3 + det_rtol = 0 if torch_type == torch.float32 else 1e-3 + torch.testing.assert_close( + out, first_out, rtol=det_rtol, atol=det_atol, msg="Output mismatch between two runs" + ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + effective_v_head_size = config.v_head_size or config.head_size + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, effective_v_head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -224,6 +236,65 @@ def parity_check_mha_past( ) new_v = torch.randn_like(new_k) * std + # Create attention mask if config requires one + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length + attn_mask = None + attn_bias_ref = None + if config.has_attn_mask: + # When softcap is present, use partial seqlens so the mask has both valid and masked + # positions — otherwise the all-zero mask can't detect softcap→bias ordering bugs. + # For non-softcap tests, use full seqlens (existing behavior). + if config.softcap > 0: + mask_valid_len = max(1, total_seq_len * 3 // 4) + else: + mask_valid_len = total_seq_len + seqlens = torch.full((config.batch_size,), mask_valid_len, dtype=torch.int32, device=device) + + if config.attn_mask_type == "bool": + # Create boolean mask for ORT (True=attend, False=mask) + arange = torch.arange(total_seq_len, device=device) + if config.attn_mask_dims == 2: + mask_1d = arange < seqlens[0] + attn_mask = mask_1d.unsqueeze(0).expand(config.q_sequence_length, -1).contiguous() + else: + attn_mask = create_boolean_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=config.attn_mask_dims, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + ) + # Create additive bias for PyTorch reference path + attn_bias_ref = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=4, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + dtype=torch_type, + ) + else: + # Additive mask: same tensor for both ORT and reference + attn_mask = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=config.attn_mask_dims, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + dtype=torch_type, + ) + if config.attn_mask_dims == 2: + attn_bias_ref = ( + attn_mask.unsqueeze(0).unsqueeze(0).expand(config.batch_size, config.q_num_heads, -1, -1) + ) + elif config.attn_mask_dims == 3: + attn_bias_ref = attn_mask.unsqueeze(0).expand(config.batch_size, -1, -1, -1) + else: + attn_bias_ref = attn_mask + # --- PyTorch Reference Path --- new_k_bnsh = new_k.transpose(1, 2) new_v_bnsh = new_v.transpose(1, 2) @@ -236,7 +307,9 @@ def parity_check_mha_past( q=q, k=full_k_bsnh, v=full_v_bsnh, + attn_bias=attn_bias_ref, causal=causal, + softcap=config.softcap, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -250,7 +323,7 @@ def parity_check_mha_past( new_k=new_k, new_v=new_v, config=config, - attn_mask=None, + attn_mask=attn_mask, ep=ep, device=device, ort_type=ort_type, @@ -258,9 +331,15 @@ def parity_check_mha_past( if i == 0: first_out = out.clone() else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + # FP16/BF16 GPU kernels may produce bit-level non-determinism across runs. 
+ det_atol = 0 if torch_type == torch.float32 else 1e-3 + det_rtol = 0 if torch_type == torch.float32 else 1e-3 + torch.testing.assert_close( + out, first_out, rtol=det_rtol, atol=det_atol, msg="Output mismatch between two runs" + ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + effective_v_head_size = config.v_head_size or config.head_size + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, effective_v_head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -367,10 +446,11 @@ def parity_check_mha_prompt_with_attn_bias( v=v, attn_bias=attn_bias_ref, causal=config.is_causal == 1, + softcap=config.softcap, ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -698,10 +778,11 @@ def parity_check_mha_prompt_with_bool_mask( v=v, key_padding_mask=key_padding_mask, causal=config.is_causal == 1, + softcap=config.softcap, ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -866,6 +947,110 @@ def test_mha_past_fp32(self, name, config): ) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEA(unittest.TestCase): + """Test ONNX Attention op MHA path — decoding with KV cache via Memory Efficient Attention. + + Explicitly forces MEA by disabling Flash Attention. This verifies that the + MEA decode path works correctly for MHA (kv_num_heads == q_num_heads). + """ + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea(self, name, config): + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEAFP32(unittest.TestCase): + """Test MHA decode via MEA with fp32 dtype.""" + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea_fp32(self, name, config): + config.kv_cache_type = "float32" + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEABoolMask(unittest.TestCase): + """Test MHA decode via MEA with boolean attention mask (converted to additive bias).""" + + def test_mha_past_bool_mask_mea(self): + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment) + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=2, + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], 
+ ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEAFloatMask(unittest.TestCase): + """Test MHA decode via MEA with float additive attention mask.""" + + def test_mha_past_float_mask_4d_mea(self): + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment) + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.") class TestONNXAttentionMHAAttnBias(unittest.TestCase): """ @@ -998,7 +1183,7 @@ def parity_check_mha_prompt_with_nonpad_kv_seqlen( # ORT path: use nonpad_kv_seqlen (int64 tensor) nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -1249,116 +1434,693 @@ def test_mha_unfused_fp16(self, name, config): atol=atol["fp16"], ) - -# ################################################################################################# -# Broadcast Mask (1,1,q,kv) Tests -# ################################################################################################# + def test_mha_unfused_decode_fp32(self): + """Test unfused decode with fp32 (both Flash and MEA disabled).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + kv_cache_type="float32", + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) -@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") -class TestONNXAttentionMHABroadcastMask(unittest.TestCase): +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping unfused softcap tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION": "1"}) +class TestONNXAttentionMHAUnfusedSoftcap(unittest.TestCase): """ - Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + Test softcap support in the unfused attention kernel. - This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that - the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. + Disables Flash and MEA to force the unfused path. Verifies that + softcap * tanh(score / softcap) is correctly applied to attention logits + before softmax, matching the reference implementation. 
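+
+    For example, with softcap=2.0 a raw scaled logit of 5.0 becomes
+    2.0 * tanh(5.0 / 2.0) ≈ 1.973, so every logit is squashed into (-2.0, 2.0)
+    before any mask is added.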
""" - def test_mha_broadcast_mask_additive(self): - """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + def test_unfused_softcap_prompt_fp16(self): + """Test softcap on unfused path during prompt (fp16).""" config = AttentionConfig( batch_size=2, - q_sequence_length=16, - kv_sequence_length=16, - q_num_heads=8, - kv_num_heads=8, - head_size=128, - is_causal=0, - has_attn_mask=True, - attn_mask_dims=4, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, attn_mask_type="additive", - broadcast_mask_batch=True, - broadcast_mask_heads=True, ) - - torch.manual_seed(0) - device = "cuda" - torch_type = torch.float16 - - q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 - k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 - v = torch.randn_like(k) * 0.2 - - # Create (1,1,q,kv) additive mask: lower-triangular causal pattern - mask_filter = float(torch.finfo(torch_type).min) - mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) - for i in range(16): - mask_2d[i, i + 1 :] = mask_filter - attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) - - # Reference: expand to full (B, H, Q, K) - attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() - out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) - - # ORT path - out_ort, _, _ = attention_prompt_func( - q=q, - k=k, - v=v, + parity_check_mha_prompt( config=config, - attn_mask=attn_mask, ep="CUDAExecutionProvider", - device=device, + device="cuda", + torch_type=torch.float16, ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], ) - out_ort = out_ort.reshape(2, 16, 8, 128) - - out_np = out_ort.float().detach().cpu().numpy() - out_ref_np = out_ref.float().detach().cpu().numpy() - numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) - - -# ################################################################################################# -# 2D Mask Broadcast Regression Test -# ################################################################################################# + def test_unfused_softcap_decode_fp16(self): + """Test softcap on unfused path during decode (fp16).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) -@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") -class TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): - """ - Regression test for 2D mask [q_seq, total_seq] broadcast correctness. - - Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts - over batch and heads. This test uses batch_size > q_seq with a non-uniform - mask (different values per row) to verify correct broadcast behavior. - - The old bug indexed the 2D mask by batch index instead of query position, - causing OOB reads when batch_size > q_seq. 
- """ + def test_unfused_softcap_prompt_fp32(self): + """Test softcap on unfused path during prompt (fp32).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + kv_cache_type="float32", + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) - def test_2d_additive_mask_batch_gt_qseq(self): - """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + def test_unfused_softcap_with_mask_prompt_fp16(self): + """Test softcap + float mask on unfused path — verifies spec-correct ordering (softcap→mask→softmax).""" config = AttentionConfig( - batch_size=4, - q_sequence_length=2, + batch_size=2, + q_sequence_length=8, kv_sequence_length=8, q_num_heads=4, kv_num_heads=4, head_size=64, is_causal=0, + softcap=2.0, has_attn_mask=True, - attn_mask_dims=2, + attn_mask_dims=4, attn_mask_type="additive", ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) - torch.manual_seed(42) - device = "cuda" - torch_type = torch.float16 - mask_filter_value = torch.finfo(torch_type).min - - q = ( - torch.randn( - config.batch_size, + def test_unfused_softcap_with_mask_decode_fp16(self): + """Test softcap + float mask on unfused decode — verifies spec-correct ordering.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- Partial masking: fp32 variants --- + + def test_unfused_softcap_with_mask_prompt_fp32(self): + """Test softcap + additive mask on unfused prompt (fp32). + + The helper auto-creates a partial mask (3/4 valid positions) when softcap > 0, + ensuring the mask has both 0.0 and -inf values to exercise the softcap→bias ordering. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + kv_cache_type="float32", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=False, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + def test_unfused_softcap_with_mask_decode_fp32(self): + """Test softcap + additive mask on unfused decode (fp32). + + Decode with past KV cache: total_seq=32, ~24 valid positions, 8 masked. 
+ """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + kv_cache_type="float32", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + # --- Partial masking: different mask dimensionalities --- + + def test_unfused_softcap_with_mask_2d_prompt_fp16(self): + """Test softcap + 2D additive mask on unfused prompt. + + A 2D mask [q_seq, kv_seq] broadcasts across batch and heads. + This tests the 2D mask indexing path in the unfused kernel. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_unfused_softcap_with_mask_3d_prompt_fp16(self): + """Test softcap + 3D additive mask on unfused prompt. + + A 3D mask [heads, q_seq, kv_seq] broadcasts across batch dimension. + This tests the 3D mask broadcast path which has its own handling branch. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=3, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- Partial masking: larger sequence (different absolute mask boundary) --- + + def test_unfused_softcap_with_mask_longer_seq_prompt_fp16(self): + """Test softcap + mask with a longer sequence (kv_seq=16). + + With kv_seq=16, mask_valid_len=12 (3/4). This exercises a different absolute + mask boundary compared to the kv_seq=8 tests (valid_len=6) and provides + a wider range of softcapped logit values interacting with the mask. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_softcap_mask_ordering_no_leakage_prompt(self): + """Guard test: verify softcap + mask ordering prevents attention leakage. + + This test PROVES the ordering matters and would FAIL if someone reverts + to the wrong ordering (mask before softcap). + + Setup: Create a mask where some KV positions are -inf (masked). Place + a distinctive 'poison' value (1000.0) in V at masked positions. With + correct ordering (softcap → mask → softmax), masked positions get + -inf after bias addition → zero attention → output uncontaminated. 
+ With wrong ordering (mask → softcap → softmax), softcap(-inf) = -softcap + (finite) → nonzero attention → output contaminated by poison values. + """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + # Only the first 4 KV positions are valid; last 4 are masked (-inf) + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float32 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + attn_mask = torch.zeros(batch_size, num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + # Run ONNX Runtime + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + + # If ordering is wrong, poison values leak into output producing extreme values. + # Valid output range with std=0.2 inputs and softcap=2.0 is roughly [-10, 10]. + # Any element > 50 indicates attention leakage to the poison=1000 positions. + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"Attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means softcap is applied AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.01, atol=0.01) + + def test_softcap_mask_ordering_no_leakage_decode(self): + """Guard test for decode (past KV) path: softcap + mask ordering prevents leakage. + + Same poison-value technique as the prompt test, but exercises the decode + code path with past KV cache. Masked positions in the past cache should + receive zero attention with correct ordering. 
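+
+        Rough magnitude check (std=0.2 inputs keep scaled scores near zero): under the
+        wrong ordering each of the 8 poisoned positions would keep roughly
+        exp(-softcap) / total_kv_seq ≈ 0.8% softmax weight, so poison=1000.0 pushes
+        |output| to about 8 * 0.008 * 1000 ≈ 68, comfortably above the 50.0 threshold.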
+ """ + batch_size = 1 + q_seq = 1 # decode: single token + kv_seq = 1 + past_kv_seq = 15 + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + total_kv_seq = past_kv_seq + kv_seq # 16 total + valid_kv_len = 8 # Only first 8 of 16 positions are valid + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + past_kv_sequence_length=past_kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float32 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Past KV with poison in masked positions + past_k = torch.randn(batch_size, num_heads, past_kv_seq, head_size, dtype=torch_type, device=device) * 0.2 + past_v = torch.randn(batch_size, num_heads, past_kv_seq, head_size, dtype=torch_type, device=device) * 0.2 + poison_value = 1000.0 + past_v[:, :, valid_kv_len:, :] = poison_value + + # Mask: 0.0 for first valid_kv_len positions, -inf for rest + attn_mask = torch.zeros(batch_size, num_heads, q_seq, total_kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + # Run ONNX Runtime via attention_past_func + out, _, _ = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=k, + new_v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"Attention leakage detected in decode path: max |output| = {max_abs:.1f}. " + f"Softcap must be applied BEFORE mask (per onnx/onnx#7865).", + ) + + +# ################################################################################################# +# Asymmetric Head Size Regression Test (MEA → unfused fallback) +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping asymmetric head size tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAAsymmetricHeadSize(unittest.TestCase): + """ + Regression test: MEA gracefully falls back to unfused when head_size != v_head_size + with past_key present (decode phase). + + Without the eligibility guard in ComputeInternal, this configuration would select + MEA which then crashes with ORT_ENFORCE because LaunchConcatNewToPastKV requires + head_size == v_head_size. The guard skips MEA and falls back to unfused attention. + + Uses MHA path (kv_num_heads == q_num_heads) because the GQA path has no unfused + fallback (returns NOT_IMPLEMENTED). 
+ """ + + def test_mha_past_asymmetric_v_head_size(self): + """Verify decode with head_size=128, v_head_size=96 doesn't crash (falls to unfused).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=128, + v_head_size=96, + is_causal=1, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + + # Past KV in BNSH: K uses head_size=128, V uses v_head_size=96 + past_k = torch.randn(2, 4, 32, 128, device=device, dtype=torch_type) * std + past_v = torch.randn(2, 4, 32, 96, device=device, dtype=torch_type) * std + + new_k = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + new_v = torch.randn(2, 1, 4, 96, device=device, dtype=torch_type) * std + + # PyTorch reference: concat past + new, compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + + # ORT path — should fall back to unfused (not crash in MEA) + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=None, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + # Reshape output: [B, q_seq, q_num_heads * v_head_size] → [B, q_seq, q_num_heads, v_head_size] + out_ort = out_ort.reshape(2, 1, 4, 96) + + # Verify present_k and present_v + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # Verify output + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# Broadcast Mask (1,1,q,kv) Tests +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") +class TestONNXAttentionMHABroadcastMask(unittest.TestCase): + """ + Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + + This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that + the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. 
+ """ + + def test_mha_broadcast_mask_additive(self): + """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=8, + head_size=128, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + broadcast_mask_batch=True, + broadcast_mask_heads=True, + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + v = torch.randn_like(k) * 0.2 + + # Create (1,1,q,kv) additive mask: lower-triangular causal pattern + mask_filter = float(torch.finfo(torch_type).min) + mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) + for i in range(16): + mask_2d[i, i + 1 :] = mask_filter + attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) + + # Reference: expand to full (B, H, Q, K) + attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + out_ort = out_ort.reshape(2, 16, 8, 128) + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# 2D Mask Broadcast Regression Test +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") +class TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): + """ + Regression test for 2D mask [q_seq, total_seq] broadcast correctness. + + Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts + over batch and heads. This test uses batch_size > q_seq with a non-uniform + mask (different values per row) to verify correct broadcast behavior. + + The old bug indexed the 2D mask by batch index instead of query position, + causing OOB reads when batch_size > q_seq. + """ + + def test_2d_additive_mask_batch_gt_qseq(self): + """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + config = AttentionConfig( + batch_size=4, + q_sequence_length=2, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + mask_filter_value = torch.finfo(torch_type).min + + q = ( + torch.randn( + config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size, @@ -1490,6 +2252,285 @@ def test_2d_bool_mask_batch_gt_qseq(self): numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMEASoftcap(unittest.TestCase): + """ + Test softcap support in the Memory Efficient Attention (MEA) kernel. 
+ + Disables Flash Attention to force the MEA path. Verifies that + softcap * tanh(score / softcap) is correctly applied to attention logits + in MEA, matching the reference implementation. + + MEA alignment requirement: total_seq % 4 == 0 when attn_mask is present. + """ + + # --- P0: MEA softcap+mask (MHA) --- + + def test_mea_softcap_with_mask_prompt_fp16(self): + """MEA softcap + additive mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, # total_seq=8, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_softcap_with_mask_decode_fp16(self): + """MEA softcap + additive mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq = 31+1 = 32, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- P0: MEA softcap-only (no mask) --- + + def test_mea_softcap_no_mask_prompt_fp16(self): + """MEA softcap without explicit mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_softcap_no_mask_decode_fp16(self): + """MEA softcap without explicit mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- P1: MEA softcap ordering poison test --- + + def test_mea_softcap_mask_ordering_no_leakage_prompt(self): + """Guard test: verify MEA softcap + mask ordering prevents attention leakage. + + Same poison-value technique as the unfused ordering test, but forces the + MEA path. Proves MEA correctly applies softcap before mask addition. 
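+
+        The failure mode this guards against, as a one-liner:
+
+            softcap * torch.tanh(torch.tensor(float("-inf")) / softcap)  # = -softcap, finite!
+
+        A finite -softcap added after masking would let softmax leak weight onto
+        the poisoned positions; with spec ordering the -inf reaches softmax intact.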
+ """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 # divisible by 4 for MEA alignment + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + attn_mask = torch.zeros(batch_size, num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"MEA attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means MEA applies softcap AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.02, atol=0.02) + + +@unittest.skipIf(not has_cuda_device(80), "Flash Attention requires Ampere or higher GPU, skipping tests.") +class TestONNXAttentionFlashSoftcap(unittest.TestCase): + """ + Test softcap support via Flash Attention path. + + Does NOT disable Flash or MEA — lets the dispatch cascade choose naturally. + On Ampere+ with fp16 and head_size<=256, this should route to Flash Attention. 
+ """ + + def test_flash_softcap_prompt_fp16(self): + """Flash Attention softcap, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_softcap_decode_fp16(self): + """Flash Attention softcap, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_softcap_with_mask_prompt_fp16(self): + """Flash Attention softcap + mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # NOTE: GQA fully-masked batch fix (ZeroOutputForFullyMaskedBatches) is validated by # C++ test Attention_NonPadKVSeqLen_AllMasked_FP16_GQA. Python graph-level test omitted # because the fix is a CUDA kernel in the MEA path — a CPU-only test cannot validate it, diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py index a6a115bb12213..6b3f6d1c3ff34 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py @@ -460,16 +460,22 @@ def cpu_test_cases(): def cuda_fp16_test_cases(): - """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly.""" + """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly. + TensorScatter manages KV cache externally with nonpad_kv_seqlen bounding the active range. + Per ONNX spec, is_causal with S_q!=S_kv and no past_key gives upper-left alignment + (q[0] sees only kv[0]), which is not meaningful for decode. KV bounds are enforced by + nonpad_kv_seqlen instead, so is_causal=0 is the correct setting for TensorScatter decode.""" yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=0) - yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=1) def cuda_fp32_test_cases(): """CUDA fp32: MHA only. GQA requires fp16/bf16, and flash attention requires fp16/bf16. - fp32 MHA uses the unfused attention_bias fallback path.""" + fp32 MHA uses the unfused attention_bias fallback path. + TensorScatter manages KV cache externally with nonpad_kv_seqlen bounding the active range. 
+ Per ONNX spec, is_causal with S_q!=S_kv and no past_key gives upper-left alignment + (q[0] sees only kv[0]), which is not meaningful for decode. KV bounds are enforced by + nonpad_kv_seqlen instead, so is_causal=0 is the correct setting for TensorScatter decode.""" yield from _make_test_params(_MHA_CASES, is_causal=0) - yield from _make_test_params(_MHA_CASES, is_causal=1) # ################################################################################################# @@ -975,5 +981,71 @@ def test_nonpad_with_bool_mask_cuda_fp16( numpy.testing.assert_allclose(present_v, ref_present_v, rtol=rtol["fp16"], atol=atol["fp16"]) +class TestCausalTensorScatterRejected(unittest.TestCase): + """Test that is_causal=1 + TensorScatter decode (S_q != S_kv, no past) is rejected. + + Per ONNX spec, is_causal without past_key means upper-left alignment: q[i] attends + only to kv[0..i]. For decode with external cache (S_q=1, S_kv=cache_size), this means + q[0] sees only kv[0] — not meaningful for autoregressive generation. + + The dispatch guard should return NOT_IMPLEMENTED for this combination. + Models should use is_causal=0 for TensorScatter decode. + """ + + @unittest.skipUnless("CUDAExecutionProvider" in get_available_providers(), "CUDA not available") + def test_is_causal_with_tensorscatter_no_past_rejected(self): + """Verify NOT_IMPLEMENTED is raised for is_causal=1 + TensorScatter + S_q != S_kv.""" + batch_size = 1 + q_seq_len = 1 + total_kv_seq_len = 8 + q_num_heads = 2 + kv_num_heads = 2 + head_size = 32 + + # Build model with is_causal=1 (the rejected combination) + model_bytes = build_tensorscatter_attention_graph( + batch_size=batch_size, + total_kv_seq_len=total_kv_seq_len, + q_seq_len=q_seq_len, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + ort_type=TensorProto.FLOAT16, + is_causal=1, + ) + + sess_opts = SessionOptions() + session = InferenceSession(model_bytes, sess_opts, providers=["CUDAExecutionProvider"]) + + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + key_cache = numpy.random.randn(batch_size, total_kv_seq_len, kv_hidden).astype(numpy.float16) + value_cache = numpy.random.randn(batch_size, total_kv_seq_len, kv_hidden).astype(numpy.float16) + new_k = numpy.random.randn(batch_size, q_seq_len, kv_hidden).astype(numpy.float16) + new_v = numpy.random.randn(batch_size, q_seq_len, kv_hidden).astype(numpy.float16) + write_indices = numpy.array([4], dtype=numpy.int64) + query = numpy.random.randn(batch_size, q_seq_len, q_hidden).astype(numpy.float16) + nonpad_kv_seqlen = numpy.array([5], dtype=numpy.int64) + + feeds = { + "key_cache": key_cache, + "value_cache": value_cache, + "new_k": new_k, + "new_v": new_v, + "write_indices": write_indices, + "query": query, + "nonpad_kv_seqlen": nonpad_kv_seqlen, + } + + with self.assertRaises(Exception) as ctx: + session.run(None, feeds) + + error_msg = str(ctx.exception) + self.assertTrue( + "NOT_IMPLEMENTED" in error_msg or "nonpad_kv_seqlen" in error_msg, + f"Expected NOT_IMPLEMENTED error for is_causal + TensorScatter decode, got: {error_msg}", + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 5f8871d71c80a..5e8a6532e974d 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -42,14 +42,9 @@ 
"^test_attention_4d_attn_mask_3d_causal_expanded*", // webgpu "^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen // TODO: support qk_matmul_output modes beyond kQK in Attention-cuda (see issue #27712) - // Tests combining qk_matmul with softcap need unfused-path softcap support (deferred). - "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // qk_matmul + softcap needs unfused softcap - "^test_attention_4d_with_qk_matmul_softcap_cuda", // qk_matmul + softcap needs unfused softcap - // softcap + diff head sizes (head_size != v_head_size) blocks Flash, falls to unfused which lacks softcap - "^test_attention_3d_diff_heads_sizes_softcap_cuda", // diff head sizes forces unfused, no softcap - "^test_attention_4d_diff_heads_sizes_softcap_cuda", // diff head sizes forces unfused, no softcap - "^test_attention_4d_attn_mask_bool_cuda", // bool mask not supported in Attention-cuda - "^test_attention_4d_attn_mask_bool_4d_cuda", // bool mask not supported in Attention-cuda + // Tests combining qk_matmul with softcap need unfused-path qk_matmul support (deferred). + "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // qk_matmul modes beyond kQK not supported + "^test_attention_4d_with_qk_matmul_softcap_cuda", // qk_matmul modes beyond kQK not supported "^test_attention_3d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_cuda", // QK matmul + bias not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_cuda", // QK matmul + bias not supported in Attention-cuda @@ -57,27 +52,6 @@ "^test_attention_4d_with_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda "^test_attention_3d_with_past_and_present_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda - // is_causal=Truen && q_seq_len != kv_seq_len not supported in Attention-cuda - "^test_attention_3d_causal_cuda", - "^test_attention_3d_diff_heads_sizes_causal_cuda", - "^test_attention_4d_attn_mask_3d_causal_cuda", - "^test_attention_4d_attn_mask_4d_causal_cuda", - "^test_attention_4d_causal_cuda", - "^test_attention_4d_diff_heads_sizes_causal_cuda", - // GQA Attention-cuda does not support fp16 and 4d QKV - "^test_attention_4d_gqa_with_past_and_present_fp16_cuda", // 4d QKV - "^test_attention_4d_gqa_with_past_and_present_cuda", // fp32 - "^test_attention_4d_gqa_softcap_cuda", // fp32 - "^test_attention_4d_gqa_scaled_cuda", // fp32 - "^test_attention_4d_gqa_cuda", // fp32 - "^test_attention_3d_gqa_attn_mask_cuda", // fp32 - "^test_attention_3d_gqa_causal_cuda", // fp32 - "^test_attention_3d_gqa_cuda", // fp32 - "^test_attention_3d_gqa_scaled_cuda", // fp32 - "^test_attention_3d_gqa_softcap_cuda", // fp32 - "^test_attention_3d_gqa_with_past_and_present_cuda", // fp32 - "^test_attention_4d_gqa_attn_mask_cuda", // fp32 - "^test_attention_4d_gqa_causal_cuda", // fp32 "^test_tensorscatter*", // TensorScatter(24) not implemented "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes