diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index abdcc81586909..75b43c363f371 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -655,7 +655,8 @@ Do not modify directly.* |ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||12|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| -|Attention|*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*in* nonpad_kv_seqlen:**tensor(int64)**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**

or

*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)| +|Attention|*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*in* nonpad_kv_seqlen:**tensor(int64)**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**

or

*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**|24+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)| +|||23|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 6c08d7fbd9b3f..8cccc2f1a725c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -132,6 +132,16 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block); +// BxNxSxH => BxSxNxH +Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block); + +Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const half* input, half* output, cudaStream_t stream, const int max_threads_per_block); + +Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block); + template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index e7177987fa2d1..e5f1ffca251b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -393,6 +393,26 @@ Status Transpose_BSNH_to_BNSH(const 
int batch_size, const int sequence_length, c max_threads_per_block, false, input, output); } +// BxNxSxH => BxSxNxH (BNSH to BSNH) — reverse of Transpose_BSNH_to_BNSH. +// Reuses the existing TransposeCtx kernel which does exactly this transformation. +Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + +Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const half* input, half* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + +Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 1b68d70617744..29bb4fba6a09a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -258,6 +258,33 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { params.qk_head_size % AlignedAK::kAlignmentK == 0 && params.v_head_size % AlignedAK::kAlignmentV == 0; + // Bias stride alignment check: route to the unaligned kernel when bias 
strides + // don't satisfy the aligned kernel's kAlignmentQ requirement. + // + // kAlignmentQ is template-dependent (kernel_forward.h:414): + // isAligned=true: kAlignmentQ = DefaultConfig::kAlignmentA (8 for fp16/bf16 SM75+) + // isAligned=false: kAlignmentQ = GemmType::kMinimumAlignment (4 for fp16/bf16 SM75+) + // So check_supported (line 632) enforces DIFFERENT thresholds per path. + // + // The ONNX Attention kernel (core/providers/cuda/llm/attention.cc) gates MEA eligibility + // at kMinimumAlignment (4), allowing strides like 12 that the unaligned kernel handles. + // Without this check, such inputs dispatch to the aligned kernel where 12%8≠0 crashes. + // Contrib MHA gates at 4*sizeof(T)=8 for fp16, making this check redundant there. + if (params.attn_bias != nullptr) { + int num_keys = params.kv_sequence_length; + int num_queries = params.sequence_length; + int bias_strideM = num_keys; + // Broadcast dimensions use stride=0, which satisfies any alignment (0 % N == 0). + int bias_strideH = params.broadcast_attn_bias_dim_1 ? 0 : num_queries * num_keys; + int bias_strideB = params.broadcast_attn_bias_dim_0 + ? 0 + : ((params.broadcast_attn_bias_dim_1 ? 
1 : params.num_heads) * num_queries * num_keys); + is_aligned = is_aligned && + bias_strideM % AlignedAK::kAlignmentQ == 0 && + (params.num_heads <= 1 || bias_strideH % AlignedAK::kAlignmentQ == 0) && + (params.batch_size <= 1 || bias_strideB % AlignedAK::kAlignmentQ == 0); + } + DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { LaunchCutlassFmha(params); })); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 961c80748d228..d2f74386b701e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -1147,6 +1147,24 @@ template Status QkvToContext<__nv_bfloat16, __nv_fp8_e4m3>( GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>& data); #endif +// Explicit instantiations for cross-TU usage by core/providers/cuda/llm/attention.cc +template Status LaunchUngroup<__half>( + const GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block); +template Status LaunchUngroup<__nv_bfloat16>( + const GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block); + template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, 
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 78b061837e402..125ab8f76132c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -122,6 +122,16 @@ struct GQABufferRequirements { } }; +template +// Also used by ONNX Attention (core/providers/cuda/llm/attention.cc) for GQA head expansion in MEA path. +Status LaunchUngroup(const GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block); + Status LaunchGetSequenceLengths( const int* total_seq_lens_minus_one, int* past_seq_lens, diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h index e7df1a078472a..efc908603a990 100644 --- a/onnxruntime/core/providers/cpu/llm/attention_helper.h +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h @@ -27,7 +27,8 @@ inline Status ComputeOutputShapeForAttention( TensorShape& y_shape, TensorShape& present_key_shape, TensorShape& present_value_shape, - TensorShape& output_qk_shape) { + TensorShape& output_qk_shape, + bool skip_nonpad_data_validation = false) { ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, "Q, K, and V inputs must not be null"); int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); @@ -90,6 +91,13 @@ inline Status ComputeOutputShapeForAttention( parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); + + // Validate mask second-to-last dim 
matches q_sequence_length (same check as 4D path). + // For 2D mask [A, B]: A must equal q_seq. For 3D mask [A, B, C]: B must equal q_seq. + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[1], + "inconsistent q_sequence_length (between attn_mask and Q)"); + parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; @@ -115,11 +123,14 @@ inline Status ComputeOutputShapeForAttention( parameters.has_nonpad_kv_seqlen = true; parameters.nonpad_kv_seqlen_data = nonpad_kv_seqlen->Data(); // Validate each value is in [0, total_sequence_length]. - for (int i = 0; i < parameters.batch_size; ++i) { - ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 && - parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length, - "nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i], - " is out of range [0, ", parameters.total_sequence_length, "]"); + // Skip when data is on GPU (CUDA path sets skip_nonpad_data_validation=true). + if (!skip_nonpad_data_validation) { + for (int i = 0; i < parameters.batch_size; ++i) { + ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 && + parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length, + "nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i], + " is out of range [0, ", parameters.total_sequence_length, "]"); + } } } else { parameters.has_nonpad_kv_seqlen = false; @@ -127,6 +138,11 @@ inline Status ComputeOutputShapeForAttention( } ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads. 
This is required for grouped/multi-query and multi-headed attention."); + // TODO: The ONNX spec allows attn_mask last dim to be shorter than total_sequence_length, + // with positions beyond the mask padded with -inf. Currently we enforce exact match. + // To support: change == to <=, allocate padded buffer, fill remainder with -inf. + // See ONNX spec: 'The last dimension can also be shorter than total_sequence_length + // and will be padded to total_sequence_length with negative infinity.' ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); ORT_ENFORCE(attn_mask == nullptr || diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index bf6fcc7ccf0a8..92e1e5108cb8a 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1592,9 +1592,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); // Opset 23. 
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization); @@ -1633,6 +1633,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 24. 
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention); #endif @@ -2671,9 +2674,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2711,6 +2714,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 24 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 8fc7a49827087..4f7fd6a664bdc 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cpu/llm/attention.h" #include "core/providers/cpu/llm/attention_helper.h" @@ -20,10 +19,11 @@ namespace onnxruntime { namespace cuda { #define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Attention, \ kOnnxDomain, \ 23, \ + 23, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -36,11 +36,26 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) REGISTER_KERNEL_TYPED(BFloat16) +#define REGISTER_KERNEL_TYPED_24(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kOnnxDomain, \ + 24, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", BuildKernelDefConstraints()), \ + Attention); + +REGISTER_KERNEL_TYPED_24(float) +REGISTER_KERNEL_TYPED_24(MLFloat16) +REGISTER_KERNEL_TYPED_24(BFloat16) + template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) { is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1; - // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs. - // The dimension is not yet known. If not specified, the inputs is assumed to be 4D. kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); @@ -52,589 +67,1100 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) { qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask || qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap || qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax, - "qk_matmul_output_mode must be 0, 1, 2, or 3."); - // The default scale depends on the input dimensions. 
It is set to nan to indicate that it should be computed. + "qk_matmul_output_mode must be one of: kNone(-1), kQK(0), kQKMask(1), kQKSoftCap(2), kQKSoftMax(3)."); scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN()); softcap_ = info.GetAttrOrDefault("softcap", 0.0f); softmax_precision_ = static_cast(info.GetAttrOrDefault("softmax_precision", 0)); ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified"); + + const auto* kernel_options = this->GetAttentionKernelOptions(); + disable_flash_attention_ = std::is_same::value || !kernel_options->UseFlashAttention(); + disable_memory_efficient_attention_ = !kernel_options->UseEfficientAttention(); } +// ============================================================================ +// Transpose helpers: eliminate repeated if-constexpr type-switch blocks. +// T is the ORT type (MLFloat16, BFloat16, float). The helpers map T to the +// corresponding CUDA type via ToCudaType::MappedType and forward to the +// overloaded Transpose functions in contrib_ops. 
+// ============================================================================ + template -Status Attention::ComputeInternal(OpKernelContext* context) const { - const Tensor* Q = context->Input(0); - const Tensor* K = context->Input(1); - const Tensor* V = context->Input(2); - const Tensor* attn_mask = context->Input(3); - const Tensor* past_key = context->Input(4); - const Tensor* past_value = context->Input(5); - const Tensor* nonpad_kv_seqlen = context->Input(6); // optional, Opset 24 +static Status TransposeBNSHtoBSNH(int batch_size, int sequence_length, + int num_heads, int head_size, + const void* input, void* output, + cudaStream_t stream, int max_threads_per_block) { + using CudaT = typename ToCudaType::MappedType; + return onnxruntime::contrib::cuda::Transpose_BNSH_to_BSNH( + batch_size, sequence_length, num_heads, head_size, + reinterpret_cast(input), + reinterpret_cast(output), + stream, max_threads_per_block); +} - attention_helper::AttentionParameters parameters; - TensorShape y_shape; - TensorShape present_key_shape; - TensorShape present_value_shape; - TensorShape output_qk_shape; +template +static Status TransposeBSNHtoBNSH(int batch_size, int sequence_length, + int num_heads, int head_size, + const void* input, void* output, + cudaStream_t stream, int max_threads_per_block) { + using CudaT = typename ToCudaType::MappedType; + return onnxruntime::contrib::cuda::Transpose_BSNH_to_BNSH( + batch_size, sequence_length, num_heads, head_size, + reinterpret_cast(input), + reinterpret_cast(output), + stream, max_threads_per_block); +} - ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( - Q, - K, - V, - attn_mask, - past_key, - past_value, - nonpad_kv_seqlen, - is_causal_, - softcap_, - softmax_precision_, - qk_matmul_output_mode_, - kv_num_heads_, - q_num_heads_, - scale_, - parameters, - y_shape, - present_key_shape, - present_value_shape, - output_qk_shape) - .IsOK(), - "Output shapes for Attention could not be computed."); +// 
============================================================================ +// ConvertAttnMaskToBias: shared helper for mask→additive bias conversion. +// Used by both Flash (nonpad+mask) and MEA paths to avoid code duplication. +// Converts bool masks to additive bias (true→0, false→mask_filter_value), +// passes float masks through directly, and sets broadcast flags from mask shape. +// ============================================================================ +template +Status Attention::ConvertAttnMaskToBias( + OpKernelContext* context, + const Tensor* attn_mask, + cudaStream_t cuda_stream, + int max_threads_per_block, + IAllocatorUniquePtr& converted_mask_buffer, + const void*& attn_bias_data, + bool& broadcast_bias_dim_0, + bool& broadcast_bias_dim_1) const { + if (attn_mask->IsDataType()) { + using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; + int64_t num_elements = attn_mask->Shape().Size(); + converted_mask_buffer = GetScratchBuffer( + num_elements * sizeof(NativeCudaT), context->GetComputeStream()); + float mask_filter_value = static_cast(std::numeric_limits::lowest()); + ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( + attn_mask->Data(), + reinterpret_cast(converted_mask_buffer.get()), + num_elements, mask_filter_value, cuda_stream, + max_threads_per_block)); + attn_bias_data = converted_mask_buffer.get(); + } else { + attn_bias_data = attn_mask->Data(); + } - Tensor* Y = context->Output(0, y_shape); - Tensor* present_key = context->Output(1, present_key_shape); - Tensor* present_value = context->Output(2, present_value_shape); - Tensor* output_qk = context->Output(3, output_qk_shape); + size_t mask_dims = attn_mask->Shape().NumDimensions(); + auto dims = attn_mask->Shape().GetDims(); + if (mask_dims == 2) { + broadcast_bias_dim_0 = true; + broadcast_bias_dim_1 = true; + } else if (mask_dims == 3) { + broadcast_bias_dim_0 = true; + broadcast_bias_dim_1 = dims[0] == 1; + } else { + broadcast_bias_dim_0 = dims[0] == 1; + 
broadcast_bias_dim_1 = dims[1] == 1; + } + return Status::OK(); +} - // To reuse the existing attention-cuda implementation in contrib ops, - // map the parameters to contribop_parameters (MHA). - onnxruntime::contrib::AttentionParameters contribop_parameters; +// ============================================================================ +// RunFlashAttention: Direct flash attention kernel call +// ============================================================================ +// +// Flash Attention dispatch paths: +// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache +// Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache +// - Supports: bool mask -> seqlens_k, no mask -> fill past_seq_len +// - 4D BNSH: transposes Q/K/V to BSNH before kernel +// Path 3: no past, no mask (prompt) -> mha_fwd +// Eligibility: fp16/bf16, head_size==v_head_size, no output_qk, +// (no mask OR bool mask + past OR nonpad_kv_seqlen without mask) +// Note: softcap and softmax_precision are early-rejected before the cascade. +// Note: nonpad_kv_seqlen + attn_mask is supported but routes to MEA/unfused, +// not Flash (Flash has no bias parameter for this combination). +// +// PERFORMANCE NOTE: ONNX Attention's internal-cache decode path (past_key/past_value) +// is ~15-30% slower than contrib GQA's decode path for grouped-query attention workloads. +// When using external KV cache via TensorScatter + nonpad_kv_seqlen (opset 24), the +// copy overhead (point 1) is eliminated. The remaining ~5-15% gap is from the missing +// XQA kernel (point 2). +// +// The internal-cache overhead comes from: +// +// 1. No past_present_share_buffer: The ONNX Attention spec requires past_key/value +// shape = (B, H, past_seq, head_size) and present_key/value shape = +// (B, H, total_seq, head_size) where total_seq = past_seq + kv_seq. +// Since past and present have different shapes, they cannot share the same buffer. 
+// Contrib GQA allows past and present to be the same tensor (in-place append), +// eliminating the memset + strided copy overhead (~67MB per decode step for typical LLM). +// This overhead does NOT apply to the external-cache path (TensorScatter + +// nonpad_kv_seqlen), which bypasses past/present entirely. +// +// 2. No XQA kernel: GQA's specialized XQA decode kernel (xqa_loader.h) requires +// past_present_share_buffer to function. Since ONNX Attention cannot share buffers +// (see point 1), XQA is fundamentally incompatible with this op's spec design. +// This accounts for the remaining ~5-15% gap even on the external-cache path. +// +// 3. These are spec-level limitations, not implementation gaps. For production LLM +// inference, the external-cache path (TensorScatter + nonpad_kv_seqlen) is +// recommended and achieves near-parity with contrib GQA performance. +// +template +Status Attention::RunFlashAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + const attention_helper::AttentionParameters& parameters) const { +#if USE_FLASH_ATTENTION + auto& device_prop = GetDeviceProp(); + auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + const bool is_bf16 = std::is_same::value; + const bool is_bsnh = parameters.transpose_output; // 3D inputs → BSNH - // QKV format: Determine based on input dimensions - // 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH - // transpose_output is true for 3D inputs, false for 4D inputs - if (!parameters.transpose_output) { - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; - contribop_parameters.is_output_bnsh = true; - } else { - // 3D inputs in BSNH format (will be transposed) - contribop_parameters.qkv_format = 
onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; - contribop_parameters.is_output_bnsh = false; + // --- Common buffer allocation --- + size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size( + parameters.q_sequence_length, parameters.batch_size, parameters.q_num_heads); + + auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = + onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.q_sequence_length, + parameters.total_sequence_length, parameters.q_num_heads, + parameters.head_size, device_prop.multiProcessorCount); + + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + + if (softmax_lse_accum_bytes > 0) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(softmax_lse_accum_buffer.get(), 0, + softmax_lse_accum_bytes, cuda_stream)); + } + if (out_accum_bytes > 0) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(out_accum_buffer.get(), 0, + out_accum_bytes, cuda_stream)); } - // Check if this is Group Query Attention (GQA) - const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads; + // --- Transpose Q from BNSH to BSNH (flash always expects Q as BSNH) --- + const void* q_data = Q->Data(); + IAllocatorUniquePtr q_bsnh_buffer; + if (!is_bsnh) { + size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * + parameters.q_num_heads * parameters.head_size; + q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream()); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.q_sequence_length, + parameters.q_num_heads, parameters.head_size, + Q->Data(), q_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + q_data = q_bsnh_buffer.get(); + } - if (is_gqa) { - // Use GQA path with Flash Attention or Memory 
Efficient Attention - // GQA only supports float16 and bfloat16 types - if constexpr (std::is_same::value) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "GQA in Attention op (CUDA) does not support float32. " - "Please use float16 or bfloat16."); + // Flash outputs BSNH. If Y expects BNSH, write to scratch then transpose. + void* out_data = Y->MutableData(); + IAllocatorUniquePtr out_bsnh_buffer; + if (!is_bsnh) { + size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * + parameters.q_num_heads * parameters.v_head_size; + out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream()); + out_data = out_bsnh_buffer.get(); + } + + bool present_kv_already_populated = false; + + // --- Path 1: nonpad_kv_seqlen (opset 24 external KV cache) --- + if (nonpad_kv_seqlen != nullptr) { + ORT_ENFORCE(parameters.past_sequence_length == 0, + "RunFlashAttention with nonpad_kv_seqlen requires K/V to be the full cache " + "(past_sequence_length must be 0, got ", + parameters.past_sequence_length, ")."); + + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( + nonpad_kv_seqlen->Data(), + seqlens_k_buffer.get(), + parameters.batch_size, + parameters.total_sequence_length, + cuda_stream, + device_prop.maxThreadsPerBlock)); + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, cuda_stream, + const_cast(q_data), + const_cast(static_cast(K->Data())), + const_cast(static_cast(V->Data())), + /*k=*/nullptr, /*v=*/nullptr, + out_data, + softmax_lse_buffer.get(), + const_cast(static_cast(seqlens_k_buffer.get())), + /*rotary_cos=*/nullptr, /*rotary_sin=*/nullptr, + /*cache_batch_idx=*/nullptr, /*leftpad_k=*/nullptr, + /*head_sink=*/nullptr, /*block_table=*/nullptr, + parameters.batch_size, parameters.q_num_heads, parameters.kv_num_heads, + parameters.head_size, + parameters.q_sequence_length, 
parameters.kv_sequence_length, + /*seqlen_k_new=*/0, /*rotary_dim=*/0, + parameters.scale, parameters.softcap, + parameters.is_causal, is_bf16, /*use_smooth_softmax=*/false, + /*past_bsnh=*/is_bsnh, + static_cast(num_splits), + softmax_lse_accum_buffer.get(), out_accum_buffer.get(), + /*local_window_size=*/-1, /*is_rotary_interleaved=*/false, + /*is_packed_qkv=*/false)); + } + // --- Path 2: Decode with past KV cache --- + else if (past_key != nullptr) { + ORT_ENFORCE(past_value != nullptr, "past_key requires past_value."); + ORT_ENFORCE(present_key != nullptr && present_value != nullptr, + "present_key/value outputs are required when past_key is provided."); + + // Zero present buffers before strided copy to avoid stale data in positions + // beyond past_seq that mha_fwd_kvcache might read during attention (matching GQA pattern). + // NOTE: This memset + strided copy is the main decode overhead vs contrib GQA. + // See PERFORMANCE NOTE above for why buffer sharing is impossible under the ONNX spec. + const size_t num_kv_rows = static_cast(parameters.batch_size) * parameters.kv_num_heads; + const size_t present_k_bytes = num_kv_rows * parameters.total_sequence_length * + parameters.head_size * sizeof(T); + const size_t present_v_bytes = num_kv_rows * parameters.total_sequence_length * + parameters.v_head_size * sizeof(T); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_key->MutableData(), 0, + present_k_bytes, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_value->MutableData(), 0, + present_v_bytes, cuda_stream)); + + // Copy past KV (BNSH) into present buffers (BNSH). Strided copy because + // past has [B, N_kv, past_seq, H] and present has [B, N_kv, total_seq, H]. 
+ const size_t past_k_row_bytes = static_cast(parameters.past_sequence_length) * + parameters.head_size * sizeof(T); + const size_t present_k_row_bytes = static_cast(parameters.total_sequence_length) * + parameters.head_size * sizeof(T); + CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync( + present_key->MutableData(), present_k_row_bytes, + past_key->Data(), past_k_row_bytes, + past_k_row_bytes, num_kv_rows, + cudaMemcpyDeviceToDevice, cuda_stream)); + + const size_t past_v_row_bytes = static_cast(parameters.past_sequence_length) * + parameters.v_head_size * sizeof(T); + const size_t present_v_row_bytes = static_cast(parameters.total_sequence_length) * + parameters.v_head_size * sizeof(T); + CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync( + present_value->MutableData(), present_v_row_bytes, + past_value->Data(), past_v_row_bytes, + past_v_row_bytes, num_kv_rows, + cudaMemcpyDeviceToDevice, cuda_stream)); + + // seqlens_k: derive per-batch sequence lengths for the KV cache. + // mha_fwd_kvcache expects seqlens_k = tokens already in cache BEFORE appending new ones. + // When a bool mask is present, it encodes total valid token count (past + new). + // Subtract kv_sequence_length to get the pre-append count. + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + if (attn_mask != nullptr && attn_mask->IsDataType()) { + size_t mask_dims = attn_mask->Shape().NumDimensions(); + auto dims = attn_mask->Shape().GetDims(); + int64_t mask_dim0 = dims[0]; + int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0; + int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0; + // Offset: mask gives total valid count, subtract kv_sequence_length for pre-append count. 
+ int seqlen_offset = -parameters.kv_sequence_length; + ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK( + attn_mask->Data(), seqlens_k_buffer.get(), + parameters.batch_size, parameters.total_sequence_length, + static_cast(mask_dims), mask_dim0, mask_dim1, mask_dim2, + cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset)); } else { - // GQA only supports 3D inputs (B, S, D) in BSNH format, not 4D inputs (B, num_heads, S, head_size) in BNSH format - if (!parameters.transpose_output) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "4D QKV inputs (BNSH format) are not supported yet in GQA path of Attention op (CUDA). " - "Please use 3D inputs (B, S, hidden_size) instead."); - } - // For now, GQA doesn't support qk_matmul_output_mode other than kNone - if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qk_matmul_output_mode is not supported yet in GQA path of Attention op (CUDA)."); - } - // GQA doesn't support softmax_precision yet - if (parameters.softmax_precision != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "softmax_precision is not supported yet in GQA path of Attention op (CUDA)."); - } - // causal attention is required for GQA - if (!parameters.is_causal) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "Non-causal attention is not supported yet in GQA path of Attention op (CUDA)."); - } - // GQA kernel expects K/V input sequence length == Q sequence length (self-attention only) - // Cross-attention (kv_sequence_length != q_sequence_length) is not supported - if (parameters.kv_sequence_length != parameters.q_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "Cross-attention (kv_sequence_length != q_sequence_length) is not supported in " - "GQA path of Attention op (CUDA). 
kv_sequence_length=", - parameters.kv_sequence_length, ", q_sequence_length=", parameters.q_sequence_length); - } + ORT_RETURN_IF_ERROR(LaunchFillInt32(seqlens_k_buffer.get(), parameters.past_sequence_length, + parameters.batch_size, cuda_stream, + device_prop.maxThreadsPerBlock)); + } - auto& device_prop = GetDeviceProp(); - - // Bridge parameters to GroupQueryAttentionParameters - onnxruntime::contrib::GroupQueryAttentionParameters gqa_parameters; - gqa_parameters.batch_size = parameters.batch_size; - gqa_parameters.sequence_length = parameters.q_sequence_length; - gqa_parameters.seqlen_past_kv_cache = parameters.past_sequence_length; - gqa_parameters.seqlen_present_kv_cache = parameters.total_sequence_length; - gqa_parameters.total_sequence_length = parameters.total_sequence_length; - gqa_parameters.kv_sequence_length = parameters.kv_sequence_length; - gqa_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; - gqa_parameters.num_heads = parameters.q_num_heads; - gqa_parameters.head_size = parameters.head_size; - gqa_parameters.v_head_size = parameters.v_head_size; - gqa_parameters.kv_hidden_size = parameters.kv_num_heads * parameters.v_head_size; - gqa_parameters.kv_num_heads = parameters.kv_num_heads; - gqa_parameters.scale = parameters.scale; - gqa_parameters.softcap = parameters.softcap; - gqa_parameters.qkv_format = contribop_parameters.qkv_format; - - // Unset or set to default values for GQA-specific fields - gqa_parameters.rotary_dim = 0; // New Attention op doesn't use rotary embeddings directly - gqa_parameters.is_unidirectional = true; // GQA requires causal attention - gqa_parameters.is_packed_qkv = false; // New Attention op has separate Q, K, V inputs - gqa_parameters.is_subsequent_prompt = false; - gqa_parameters.is_first_prompt = parameters.past_sequence_length == 0; - gqa_parameters.do_rotary = false; // New Attention op doesn't use rotary embeddings - gqa_parameters.rotary_interleaved = false; - 
gqa_parameters.use_smooth_softmax = false; - gqa_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; - gqa_parameters.past_kv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; - gqa_parameters.local_window_size = -1; // No local window for standard attention - gqa_parameters.zeros_count = 0; - gqa_parameters.zero_ptr = nullptr; - gqa_parameters.num_splits = 1; - - // Construct GroupQueryAttentionData - typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - onnxruntime::contrib::cuda::GroupQueryAttentionData gqa_data; - - // Scratch buffers for flash/memory efficient attention - IAllocatorUniquePtr k_buffer; - IAllocatorUniquePtr v_buffer; - IAllocatorUniquePtr fmha_buffer; - IAllocatorUniquePtr unpacked_qkv_buffer; - IAllocatorUniquePtr seq_lens_buffer; - IAllocatorUniquePtr seqlens_k_buffer; - - // Present KV cache buffers - GQA kernel uses these as working buffers - // If outputs are not provided, we allocate scratch buffers - IAllocatorUniquePtr present_key_scratch; - IAllocatorUniquePtr present_value_scratch; - - // Set input pointers - gqa_data.query = reinterpret_cast(Q->Data()); - gqa_data.key = reinterpret_cast(K->Data()); - gqa_data.value = reinterpret_cast(V->Data()); - gqa_data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - gqa_data.past_value = (past_value == nullptr) ? 
nullptr : reinterpret_cast(past_value->Data()); - - // Set output pointers - gqa_data.output = reinterpret_cast(Y->MutableData()); - - // GQA kernel requires present_key/present_value buffers as working storage for KV cache - // Allocate scratch buffers if outputs are not provided - size_t present_kv_size = static_cast(parameters.batch_size) * - static_cast(parameters.kv_num_heads) * - static_cast(parameters.total_sequence_length) * - static_cast(parameters.head_size) * sizeof(CudaT); - if (present_key != nullptr) { - gqa_data.present_key = reinterpret_cast(present_key->MutableData()); - } else { - present_key_scratch = GetScratchBuffer(present_kv_size, context->GetComputeStream()); - gqa_data.present_key = reinterpret_cast(present_key_scratch.get()); - } - if (present_value != nullptr) { - gqa_data.present_value = reinterpret_cast(present_value->MutableData()); - } else { - present_value_scratch = GetScratchBuffer(present_kv_size, context->GetComputeStream()); - gqa_data.present_value = reinterpret_cast(present_value_scratch.get()); - } + // K/V new tokens: mha_fwd_kvcache expects BSNH for k_new/v_new. + // When !is_bsnh (4D BNSH input), transpose new tokens to BSNH. 
+ const void* k_new = K->Data(); + const void* v_new = V->Data(); + IAllocatorUniquePtr k_bsnh_buffer; + IAllocatorUniquePtr v_bsnh_buffer; + if (!is_bsnh) { + size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * + parameters.kv_num_heads * parameters.head_size; + size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * + parameters.kv_num_heads * parameters.v_head_size; + k_bsnh_buffer = GetScratchBuffer(k_bytes, context->GetComputeStream()); + v_bsnh_buffer = GetScratchBuffer(v_bytes, context->GetComputeStream()); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), k_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), v_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + k_new = k_bsnh_buffer.get(); + v_new = v_bsnh_buffer.get(); + } - // Compute past_present_share_buffer early since it's needed for flash attention path selection - gqa_parameters.past_present_share_buffer = (gqa_data.past_key == gqa_data.present_key); + // mha_fwd_kvcache: present_key/value as cache (BNSH), K/V as new tokens (BSNH). + // The kernel appends new tokens at position seqlens_k[b] and attends to all. 
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, cuda_stream, + const_cast(q_data), + static_cast(present_key->MutableData()), + static_cast(present_value->MutableData()), + const_cast(k_new), const_cast(v_new), + out_data, + softmax_lse_buffer.get(), + static_cast(seqlens_k_buffer.get()), + /*rotary_cos=*/nullptr, /*rotary_sin=*/nullptr, + /*cache_batch_idx=*/nullptr, /*leftpad_k=*/nullptr, + /*head_sink=*/nullptr, /*block_table=*/nullptr, + parameters.batch_size, parameters.q_num_heads, parameters.kv_num_heads, + parameters.head_size, + parameters.q_sequence_length, parameters.total_sequence_length, + parameters.kv_sequence_length, /*rotary_dim=*/0, + parameters.scale, parameters.softcap, + parameters.is_causal, is_bf16, /*use_smooth_softmax=*/false, + /*past_bsnh=*/false, // present cache is BNSH + static_cast(num_splits), + softmax_lse_accum_buffer.get(), out_accum_buffer.get(), + /*local_window_size=*/-1, /*is_rotary_interleaved=*/false, + /*is_packed_qkv=*/false)); - // Flash Attention buffers - IAllocatorUniquePtr softmax_lse_buffer; - IAllocatorUniquePtr softmax_lse_accum_buffer; - IAllocatorUniquePtr out_accum_buffer; + present_kv_already_populated = true; + } + // --- Path 3: Prompt flash (no past, no mask) --- + // Note: prompt with bool mask is handled by MEA (flash_eligible excludes it). 
+ else { + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, cuda_stream, + const_cast(q_data), + const_cast(static_cast(K->Data())), + const_cast(static_cast(V->Data())), + out_data, + softmax_lse_buffer.get(), + parameters.batch_size, parameters.q_num_heads, parameters.kv_num_heads, + parameters.head_size, + parameters.q_sequence_length, parameters.kv_sequence_length, + parameters.scale, parameters.softcap, + parameters.is_causal, is_bf16, /*use_smooth_softmax=*/false, + static_cast(num_splits), + softmax_lse_accum_buffer.get(), out_accum_buffer.get(), + is_bsnh)); + } - // Check Flash Attention support -#if USE_FLASH_ATTENTION - bool use_flash_attention = onnxruntime::flash::is_supported(device_prop, - gqa_parameters.head_size, - gqa_parameters.num_heads, - gqa_parameters.kv_num_heads); - - gqa_data.use_flash_attention = use_flash_attention; - gqa_data.use_flash_attention_fast_decode = use_flash_attention && - !gqa_parameters.is_first_prompt && - gqa_parameters.past_present_share_buffer; - - if (use_flash_attention) { - // Allocate Flash specific buffers (Softmax LSE, Accum) - size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size( - gqa_parameters.sequence_length, gqa_parameters.batch_size, gqa_parameters.num_heads); - - int num_heads_for_split = gqa_data.use_flash_attention_fast_decode - ? gqa_parameters.kv_num_heads - : gqa_parameters.num_heads; - auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = - onnxruntime::flash::get_num_splits_and_buffer_sizes( - gqa_parameters.batch_size, gqa_parameters.sequence_length, - gqa_parameters.total_sequence_length, num_heads_for_split, - gqa_parameters.head_size, device_prop.multiProcessorCount); - - gqa_parameters.num_splits = static_cast(num_splits); - - if (gqa_data.use_flash_attention_fast_decode && num_splits > 1) { - // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel. - // However, the LSE and Accum buffers must store results for ALL num_heads. 
- softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( - num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length); - auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; - out_accum_bytes = onnxruntime::flash::get_out_accum_size( - num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length, - round_multiple(gqa_parameters.head_size, 32)); - } - - softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); - - gqa_data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); - gqa_data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); - gqa_data.out_accum = reinterpret_cast(out_accum_buffer.get()); - } else { - gqa_data.softmax_lse = nullptr; - gqa_data.softmax_lse_accum = nullptr; - gqa_data.out_accum = nullptr; - } + // --- Transpose output BSNH → BNSH if input was 4D (BNSH) --- + if (!is_bsnh && out_bsnh_buffer != nullptr) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.q_sequence_length, + parameters.q_num_heads, parameters.v_head_size, + out_bsnh_buffer.get(), Y->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } + + // --- Populate present_key/value (BNSH) from K/V (BSNH) --- + // Skip for decode path where mha_fwd_kvcache already populated present buffers. 
+ if (!present_kv_already_populated) { + if (present_key != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), present_key->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_key != nullptr && !is_bsnh) { + // 4D BNSH prompt: K is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_key->MutableData(), K->Data(), + K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + if (present_value != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), present_value->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_value != nullptr && !is_bsnh) { + // 4D BNSH prompt: V is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_value->MutableData(), V->Data(), + V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + } + + return Status::OK(); #else - gqa_data.use_flash_attention = false; - gqa_data.use_flash_attention_fast_decode = false; - gqa_data.softmax_lse = nullptr; - gqa_data.softmax_lse_accum = nullptr; - gqa_data.out_accum = nullptr; + ORT_UNUSED_PARAMETER(context); + ORT_UNUSED_PARAMETER(Q); + ORT_UNUSED_PARAMETER(K); + ORT_UNUSED_PARAMETER(V); + ORT_UNUSED_PARAMETER(attn_mask); + ORT_UNUSED_PARAMETER(past_key); + ORT_UNUSED_PARAMETER(past_value); + ORT_UNUSED_PARAMETER(nonpad_kv_seqlen); + ORT_UNUSED_PARAMETER(Y); + ORT_UNUSED_PARAMETER(present_key); + ORT_UNUSED_PARAMETER(present_value); + ORT_UNUSED_PARAMETER(parameters); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Flash attention is not available in this build."); #endif +} - // Check Memory Efficient Attention support (fallback if flash attention not available) +// 
============================================================================ +// RunMemoryEfficientAttention: Direct memory-efficient attention kernel call +// ============================================================================ +// +// Memory Efficient Attention (cutlass FMHA) dispatch paths: +// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode +// Path 2: no past, with mask (prompt) -> standard MEA with additive bias +// Path 3: no past, no mask (prompt) -> standard MEA +// Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype, +// head_size <= 1024), plus: no output_qk, no past_key (decode excluded), +// bias stride alignment. +// Note: softcap and softmax_precision are early-rejected before the cascade. +// +template +Status Attention::RunMemoryEfficientAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + const attention_helper::AttentionParameters& parameters) const { #if USE_MEMORY_EFFICIENT_ATTENTION - if (!gqa_data.use_flash_attention) { - int sm = (device_prop.major * 10) + device_prop.minor; - bool use_memory_efficient_attention = - onnxruntime::contrib::cuda::has_memory_efficient_attention( - sm, std::is_same::value, std::is_same::value, - gqa_parameters.head_size, gqa_parameters.head_size); - gqa_data.use_memory_efficient_attention = use_memory_efficient_attention; - - // KV buffer for head expansion (when num_heads != kv_num_heads) - size_t kv_buffer_bytes = (use_memory_efficient_attention && - (gqa_parameters.num_heads != gqa_parameters.kv_num_heads)) - ? 
(sizeof(T) * gqa_parameters.batch_size * gqa_parameters.num_heads * - gqa_parameters.seqlen_present_kv_cache * gqa_parameters.head_size) - : 0; - // FMHA workspace - size_t fmha_buffer_bytes = - (use_memory_efficient_attention && - onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace( - gqa_parameters.head_size, sizeof(T) == sizeof(float))) - ? (sizeof(float) * gqa_parameters.batch_size * gqa_parameters.sequence_length * - gqa_parameters.num_heads * gqa_parameters.head_size) - : 0; - - k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); - - gqa_data.k = reinterpret_cast(k_buffer.get()); - gqa_data.v = reinterpret_cast(v_buffer.get()); - gqa_data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); - } else { - gqa_data.use_memory_efficient_attention = false; - gqa_data.k = nullptr; - gqa_data.v = nullptr; - gqa_data.fmha_buffer = nullptr; - } -#else - gqa_data.use_memory_efficient_attention = false; - gqa_data.k = nullptr; - gqa_data.v = nullptr; - gqa_data.fmha_buffer = nullptr; -#endif + ORT_UNUSED_PARAMETER(past_key); + ORT_UNUSED_PARAMETER(past_value); + auto& device_prop = GetDeviceProp(); + auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + const bool is_bsnh = parameters.transpose_output; + const int sm = device_prop.major * 10 + device_prop.minor; - // Centralized scratch buffer allocation using GQABufferRequirements - auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute( - gqa_parameters, - false, // use_xqa - gqa_data.use_flash_attention, - gqa_data.use_flash_attention_fast_decode, - gqa_data.use_memory_efficient_attention); - - if (buffer_req.qkv_buffer_bytes > 0) { - unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, context->GetComputeStream()); - gqa_data.qkv_buffer = 
reinterpret_cast(unpacked_qkv_buffer.get()); - } else { - gqa_data.qkv_buffer = nullptr; - } + // Q/K/V pointers — MEA expects BSNH format for Q + const void* q_data = Q->Data(); + const void* k_data = K->Data(); + const void* v_data = V->Data(); - // Allocate GPU buffer for seqlens_k (total_sequence_length - 1) for GQA compatibility - // The GQA kernel expects sequence length information for flash/memory efficient attention - seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); - - // GQA only supports masking, not additive bias. - // For bool mask, we need to convert it to sequence lengths on GPU. - // Note: The GQA path interprets 2D bool masks as (batch_size, total_seq_len) since it converts - // masks to seqlens_k directly (bypassing ONNX right-aligned broadcasting). This differs from - // the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so - // 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast. 
- if (attn_mask != nullptr && attn_mask->IsDataType()) { - // Allocate validation result buffer on GPU - // Get mask dimensions for broadcasting - // attn_mask can be 2D, 3D, or 4D and broadcasts to (batch_size, num_heads, q_seq_len, total_seq_len) - const auto& mask_shape = attn_mask->Shape(); - int mask_dims = static_cast(mask_shape.NumDimensions()); - int64_t mask_dim0 = 0, mask_dim1 = 0, mask_dim2 = 0; - - if (mask_dims == 2) { - // Shape: (batch_size or 1, total_seq_len) - mask_dim0 = mask_shape[0]; - mask_dim1 = 0; - mask_dim2 = 0; - } else if (mask_dims == 3) { - // Shape: (num_heads or 1, q_seq_len, total_seq_len) - mask_dim0 = mask_shape[0]; - mask_dim1 = mask_shape[1]; - mask_dim2 = 0; - } else if (mask_dims == 4) { - // Shape: (batch_size or 1, num_heads or 1, q_seq_len, total_seq_len) - mask_dim0 = mask_shape[0]; - mask_dim1 = mask_shape[1]; - mask_dim2 = mask_shape[2]; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Boolean attn_mask must be 2D, 3D, or 4D. Got ", mask_dims, "D."); - } - - // Launch CUDA kernel to convert mask to seqlens_k and validate - // Mask validity (right-padding, contiguous) is checked asynchronously via CUDA_KERNEL_ASSERT. 
- ORT_RETURN_IF_ERROR(LaunchConvertMaskToSeqlensK( - attn_mask->Data(), - seqlens_k_buffer.get(), - parameters.batch_size, - parameters.total_sequence_length, - mask_dims, - mask_dim0, - mask_dim1, - mask_dim2, - cuda_stream, - device_prop.maxThreadsPerBlock)); - } else if (attn_mask != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA)."); - } else { - // No mask provided - use full sequence length for all batches - // seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention) - // Fill on GPU using cudaMemset-like approach or a simple kernel - std::vector seqlens_k_host(parameters.batch_size, parameters.total_sequence_length - 1); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(seqlens_k_buffer.get(), seqlens_k_host.data(), - sizeof(int) * parameters.batch_size, - cudaMemcpyHostToDevice, cuda_stream)); - } + // --- Transpose Q from BNSH to BSNH if 4D input --- + IAllocatorUniquePtr q_bsnh_buffer; + if (!is_bsnh) { + size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * + parameters.q_num_heads * parameters.head_size; + q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream()); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.q_sequence_length, + parameters.q_num_heads, parameters.head_size, + Q->Data(), q_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + q_data = q_bsnh_buffer.get(); + } + + // MEA output is BSNH. If Y expects BNSH, write to scratch then transpose. 
+ void* out_data = Y->MutableData(); + IAllocatorUniquePtr out_bsnh_buffer; + if (!is_bsnh) { + size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * + parameters.q_num_heads * parameters.v_head_size; + out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream()); + out_data = out_bsnh_buffer.get(); + } - // Process seqlens_k to compute past_seq_lens, total_seq_lens, and padded_seq_lens - // This is always needed for flash/memory efficient attention - seq_lens_buffer = GetScratchBuffer(3 * parameters.batch_size, context->GetComputeStream()); - gqa_data.past_seq_lens = seq_lens_buffer.get(); - gqa_data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size; - gqa_data.padded_seq_lens = gqa_data.total_seq_lens + parameters.batch_size; - - ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchGetSequenceLengths( - seqlens_k_buffer.get(), - gqa_data.past_seq_lens, - gqa_data.total_seq_lens, - gqa_data.padded_seq_lens, - parameters.batch_size, - parameters.q_sequence_length, - gqa_parameters.is_first_prompt, + // GQA head expansion: MEA requires matching num_heads for Q/K/V. + // When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup. + const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads; + IAllocatorUniquePtr k_expand_buffer; + IAllocatorUniquePtr v_expand_buffer; + + if (is_gqa) { + // GQA+MEA only works with fp16/bf16 (LaunchUngroup lacks fp32 template instantiation + // in group_query_attention_impl.cu). + // Use if constexpr to avoid instantiating LaunchUngroup. 
+ if constexpr (std::is_same_v) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "GQA with Memory Efficient Attention requires fp16 or bf16, not fp32."); + } else { + ORT_ENFORCE(parameters.head_size == parameters.v_head_size, + "GQA with MEA requires head_size == v_head_size for LaunchUngroup."); + ORT_ENFORCE(parameters.head_size % 4 == 0, + "GQA with MEA requires head_size divisible by 4 for LaunchUngroup (float2 access)."); + const size_t expanded_kv_elements = static_cast(parameters.batch_size) * + static_cast(parameters.total_sequence_length) * + static_cast(parameters.q_num_heads) * + static_cast(parameters.head_size); + k_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), context->GetComputeStream()); + v_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), context->GetComputeStream()); + + onnxruntime::contrib::GroupQueryAttentionParameters ungroup_params = {}; + ungroup_params.batch_size = parameters.batch_size; + ungroup_params.num_heads = parameters.q_num_heads; + ungroup_params.kv_num_heads = parameters.kv_num_heads; + ungroup_params.head_size = parameters.head_size; + + using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchUngroup( + ungroup_params, + reinterpret_cast(k_expand_buffer.get()), + reinterpret_cast(v_expand_buffer.get()), + reinterpret_cast(k_data), + reinterpret_cast(v_data), + parameters.total_sequence_length, + parameters.total_sequence_length, + is_bsnh, cuda_stream, device_prop.maxThreadsPerBlock)); - // Set GQA-specific fields - gqa_data.cos_cache = nullptr; // No rotary embeddings - gqa_data.sin_cache = nullptr; - gqa_data.head_sink = nullptr; - gqa_data.position_ids = nullptr; - - // Call GQA kernel (with flash or memory efficient attention) - cublasHandle_t cublas = GetCublasHandle(context); - - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data); - } // else 
(non-float GQA path) - } else { // MHA path (kv_num_heads == q_num_heads) - typedef typename ToCudaType::MappedType CudaT; - contribop_parameters.batch_size = parameters.batch_size; - contribop_parameters.sequence_length = parameters.q_sequence_length; - contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; - contribop_parameters.past_sequence_length = parameters.past_sequence_length; - contribop_parameters.total_sequence_length = parameters.total_sequence_length; - // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size) - contribop_parameters.max_sequence_length = parameters.total_sequence_length; - contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V - contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; - contribop_parameters.head_size = parameters.head_size; - contribop_parameters.v_head_size = parameters.v_head_size; - contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; - contribop_parameters.num_heads = parameters.q_num_heads; - contribop_parameters.rotary_dim = 0; - contribop_parameters.num_splits = 1; - contribop_parameters.beam_width = 1; - contribop_parameters.is_unidirectional = parameters.is_causal; - contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer - contribop_parameters.is_packed_qkv = false; - contribop_parameters.do_rotary = false; - - // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask - // So mask_type should always be MASK_NONE since we don't have a separate padding mask input - contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; - - // Determine broadcast flags for attention_bias (if it exists) - // The MHA path uses attn_mask as attention_bias (additive bias added before softmax). 
- // Bool masks are element-wise converted to additive bias (true → 0.0, false → -inf), - // preserving the original shape, so the same broadcasting rules apply to both types. - // - // ONNX broadcasting is right-aligned to target shape (batch, heads, q_seq, total_seq): - // 2D [A, B] → [1, 1, A, B] : A = q_seq_len, B = total_seq_len - // 3D [A, B, C] → [1, A, B, C] : A = heads, B = q_seq_len, C = total_seq_len - // 4D [A, B, C, D] → [A, B, C, D] : A = batch, B = heads, C = q_seq_len, D = total_seq_len - // - // Note: A 2D mask cannot represent per-batch padding because the batch dimension is broadcast. - // For per-batch boolean padding masks, use 4D shape (batch, 1, 1, total_seq_len). + k_data = k_expand_buffer.get(); + v_data = v_expand_buffer.get(); + } + } + + // Note: MEA with past_key/value is handled by the unfused fallback. + // The cascade in ComputeInternal ensures past_key == nullptr when we reach here. + + // Handle attention mask → attention_bias conversion + IAllocatorUniquePtr converted_mask_buffer; + const void* attn_bias_data = nullptr; + bool broadcast_bias_dim_0 = false; + bool broadcast_bias_dim_1 = false; + + if (nonpad_kv_seqlen != nullptr) { + // Convert nonpad_kv_seqlen to seqlens_k for custom right padding. + // MEA expects actual token count (not count-1), so use FlashSeqlensK variant. + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( + nonpad_kv_seqlen->Data(), + seqlens_k_buffer.get(), + parameters.batch_size, + parameters.total_sequence_length, + cuda_stream, + device_prop.maxThreadsPerBlock)); + + // When attn_mask is also provided, convert it to additive attn_bias so MEA + // applies both custom right padding (seqlens_k) and the attention mask (attn_bias). 
if (attn_mask != nullptr) { - size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); - auto attn_mask_dims = attn_mask->Shape().GetDims(); - // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting - // For 3D mask (heads_or_1, q_seq_len, total_seq_len): batch always broadcasts, heads broadcasts if dim[0]==1 - // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1 - - if (attn_mask_dims_size == 2) { - // 2D mask: both dimensions need broadcasting - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask_dims_size == 3) { - // 3D mask [A, q_seq_len, total_seq_len]: right-aligned to [_, A, q_seq, total_seq] - // A maps to heads dimension (validated to be 1 or q_num_heads by attention_helper.h) - // Batch dimension is missing, so always broadcasts - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1; - } else { - // 4D mask: check both dim 0 and dim 1 explicitly - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; - } - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = false; + ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream, + device_prop.maxThreadsPerBlock, + converted_mask_buffer, attn_bias_data, + broadcast_bias_dim_0, broadcast_bias_dim_1)); } - contribop_parameters.mask_filter_value = static_cast(std::numeric_limits::lowest()); - contribop_parameters.scale = parameters.scale; - contribop_parameters.use_tf32 = UseTF32(); - // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now - if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && - qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { - 
ORT_THROW("qk_matmul_output_mode other than -1 (None) and 0 (QK) is not supported yet in Attention op (CUDA)."); + onnxruntime::contrib::cuda::MemoryEfficientAttentionParams p; + p.sm = sm; + p.is_half = std::is_same::value; + p.is_bf16 = std::is_same::value; + p.is_kv_bsnh = is_bsnh; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.q_num_heads; + p.sequence_length = parameters.q_sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_causal; + p.scale = parameters.scale; + p.seqlen_k_ptr = seqlens_k_buffer.get(); + p.has_custom_right_padding = true; + p.broadcast_attn_bias_dim_0 = broadcast_bias_dim_0; + p.broadcast_attn_bias_dim_1 = broadcast_bias_dim_1; + p.query = q_data; + p.key = k_data; + p.value = v_data; + p.attn_bias = attn_bias_data; + p.stream = cuda_stream; + p.output = out_data; + + IAllocatorUniquePtr workspace_buffer; + if (onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace( + parameters.v_head_size, sizeof(T) == sizeof(float))) { + size_t workspace_bytes = sizeof(float) * parameters.batch_size * parameters.q_sequence_length * + parameters.q_num_heads * parameters.v_head_size; + workspace_buffer = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + p.workspace = workspace_buffer.get(); + } else { + p.workspace = nullptr; } - // TODO(titaiwang, xadupre): softcap and softmax_precision are not used yet - if (parameters.softcap != 0.0f) { - ORT_THROW("softcap is not supported yet in Attention op (CUDA)."); + onnxruntime::contrib::cuda::run_memory_efficient_attention(p); + } + // Standard MEA path: float attention bias, bool mask (converted to bias), or no mask. 
+ // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value) + // which correctly handles all-false masks (uniform softmax weights) unlike the + // custom_right_padding seqlens approach which would produce NaN. + else { + if (attn_mask != nullptr) { + ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream, + device_prop.maxThreadsPerBlock, + converted_mask_buffer, attn_bias_data, + broadcast_bias_dim_0, broadcast_bias_dim_1)); } - if (parameters.softmax_precision != 0) { - ORT_THROW("softmax_precision is not supported yet in Attention op (CUDA)."); + + onnxruntime::contrib::cuda::MemoryEfficientAttentionParams p; + p.sm = sm; + p.is_half = std::is_same::value; + p.is_bf16 = std::is_same::value; + p.is_kv_bsnh = is_bsnh; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.q_num_heads; + p.sequence_length = parameters.q_sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_causal; + p.scale = parameters.scale; + p.broadcast_attn_bias_dim_0 = broadcast_bias_dim_0; + p.broadcast_attn_bias_dim_1 = broadcast_bias_dim_1; + p.query = q_data; + p.key = k_data; + p.value = v_data; + p.attn_bias = attn_bias_data; + p.stream = cuda_stream; + p.output = out_data; + + IAllocatorUniquePtr workspace_buffer; + if (onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace( + parameters.v_head_size, sizeof(T) == sizeof(float))) { + size_t workspace_bytes = sizeof(float) * parameters.batch_size * parameters.q_sequence_length * + parameters.q_num_heads * parameters.v_head_size; + workspace_buffer = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + p.workspace = workspace_buffer.get(); + } else { + p.workspace = nullptr; } + onnxruntime::contrib::cuda::run_memory_efficient_attention(p); + } + + // --- 
Transpose output BSNH → BNSH if input was 4D (BNSH) --- + if (!is_bsnh && out_bsnh_buffer != nullptr) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.q_sequence_length, + parameters.q_num_heads, parameters.v_head_size, + out_bsnh_buffer.get(), Y->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } + + // Populate present_key/present_value (BNSH) if requested + if (present_key != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), present_key->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_key != nullptr && !is_bsnh) { + // 4D BNSH prompt: K is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_key->MutableData(), K->Data(), + K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + if (present_value != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), present_value->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_value != nullptr && !is_bsnh) { + // 4D BNSH prompt: V is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_value->MutableData(), V->Data(), + V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + + return Status::OK(); +#else + ORT_UNUSED_PARAMETER(context); + ORT_UNUSED_PARAMETER(Q); + ORT_UNUSED_PARAMETER(K); + ORT_UNUSED_PARAMETER(V); + ORT_UNUSED_PARAMETER(attn_mask); + ORT_UNUSED_PARAMETER(past_key); + ORT_UNUSED_PARAMETER(past_value); + ORT_UNUSED_PARAMETER(nonpad_kv_seqlen); + ORT_UNUSED_PARAMETER(Y); + ORT_UNUSED_PARAMETER(present_key); + ORT_UNUSED_PARAMETER(present_value); + ORT_UNUSED_PARAMETER(parameters); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + 
"Memory efficient attention is not available in this build."); +#endif +} + +// ============================================================================ +// RunUnfusedAttention: Delegates to MHA's QkvToContext (unfused GEMM+softmax+GEMM) +// ============================================================================ +// +// Unfused Attention dispatch paths: +// Universal fallback via MHA's QkvToContext. +// Path 1: nonpad_kv_seqlen only -> converts to attention_bias [B, q_seq, total_seq] +// Path 2: nonpad_kv_seqlen + attn_mask -> composes both into attention_bias [B, q_seq, total_seq] +// (nonpad bias + mask bias added element-wise with cyclic broadcasting) +// Path 3: all other cases -> passes mask/bias directly +// Supports: all dtypes (fp16/bf16/fp32), all mask types (bool/float/none), all head sizes +// Not supported: softcap, softmax_precision, output_qk modes beyond kNone/kQK +// Limitation: MHA only (q_num_heads must equal kv_num_heads) +// +template +Status Attention::RunUnfusedAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, + const attention_helper::AttentionParameters& parameters) const { + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + + // Bridge to contrib::AttentionParameters for the MHA unfused path + onnxruntime::contrib::AttentionParameters contribop_parameters; + + if (!parameters.transpose_output) { + contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; + contribop_parameters.is_output_bnsh = true; + } else { + contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; + contribop_parameters.is_output_bnsh = false; + } + + 
contribop_parameters.batch_size = parameters.batch_size; + contribop_parameters.sequence_length = parameters.q_sequence_length; + contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; + contribop_parameters.past_sequence_length = parameters.past_sequence_length; + contribop_parameters.total_sequence_length = parameters.total_sequence_length; + contribop_parameters.max_sequence_length = parameters.total_sequence_length; + contribop_parameters.input_hidden_size = 0; + contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; + contribop_parameters.head_size = parameters.head_size; + contribop_parameters.v_head_size = parameters.v_head_size; + contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; + contribop_parameters.num_heads = parameters.q_num_heads; + contribop_parameters.rotary_dim = 0; + contribop_parameters.num_splits = 1; + contribop_parameters.beam_width = 1; + contribop_parameters.is_unidirectional = parameters.is_causal; + contribop_parameters.past_present_share_buffer = false; + contribop_parameters.is_packed_qkv = false; + contribop_parameters.do_rotary = false; + contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; + contribop_parameters.mask_filter_value = static_cast(std::numeric_limits::lowest()); + contribop_parameters.scale = parameters.scale; + contribop_parameters.use_tf32 = UseTF32(); - // Construct AttentionData to pass to QkvToContext - onnxruntime::contrib::cuda::AttentionData data; - - // Set input pointers - data.query = reinterpret_cast(Q->Data()); - data.key = reinterpret_cast(K->Data()); - data.value = reinterpret_cast(V->Data()); - data.mask_index = nullptr; // New Attention op doesn't have key_padding_mask - data.mask_index_dims = gsl::span(); - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? 
nullptr : reinterpret_cast(past_value->Data()); - - // Set output pointers - data.output = reinterpret_cast(Y->MutableData()); - data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast(present_value->MutableData()); - if (nullptr != output_qk) { - data.output_qk = reinterpret_cast(output_qk->MutableData()); + // Determine broadcast flags for attention_bias + if (attn_mask != nullptr) { + size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); + auto attn_mask_dims = attn_mask->Shape().GetDims(); + if (attn_mask_dims_size == 2) { + contribop_parameters.broadcast_attn_bias_dim_0 = true; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else if (attn_mask_dims_size == 3) { + contribop_parameters.broadcast_attn_bias_dim_0 = true; + contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1; + } else { + contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; + contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; } + } else { + contribop_parameters.broadcast_attn_bias_dim_0 = false; + contribop_parameters.broadcast_attn_bias_dim_1 = false; + } + + // Construct AttentionData + onnxruntime::contrib::cuda::AttentionData data; + data.query = reinterpret_cast(Q->Data()); + data.key = reinterpret_cast(K->Data()); + data.value = reinterpret_cast(V->Data()); + data.mask_index = nullptr; + data.mask_index_dims = gsl::span(); + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + data.output = reinterpret_cast(Y->MutableData()); + data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (present_value == nullptr) ? 
nullptr : reinterpret_cast(present_value->MutableData()); + if (output_qk != nullptr) { + data.output_qk = reinterpret_cast(output_qk->MutableData()); + } + data.bias = nullptr; + + // Handle attention mask / nonpad_kv_seqlen → attention_bias + IAllocatorUniquePtr converted_mask_buffer; + IAllocatorUniquePtr mask_bias_buffer; // temp buffer for mask→bias when composing + if (nonpad_kv_seqlen != nullptr) { + // Convert nonpad_kv_seqlen to additive attention bias: [B, q_seq, total_seq] + using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; + int64_t bias_elements = static_cast(parameters.batch_size) * + parameters.q_sequence_length * + parameters.total_sequence_length; + converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), context->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias( + nonpad_kv_seqlen->Data(), + reinterpret_cast(converted_mask_buffer.get()), + parameters.batch_size, + parameters.q_sequence_length, + parameters.total_sequence_length, + contribop_parameters.mask_filter_value, + cuda_stream, + device_prop.maxThreadsPerBlock)); + + // When attn_mask is also present, compose it into the nonpad bias additively. + // The nonpad bias is [B, q, t]; the mask is added with cyclic broadcasting + // (e.g. a 2D [q, t] mask repeats over the batch dimension). + // Only 2D masks and 4D masks with head_dim=1 are supported — per-head masks + // (3D [H,q,t] or 4D [B,H>1,q,t]) cannot be composed into a [B,q,t] buffer. + if (attn_mask != nullptr) { + const auto& mask_shape = attn_mask->Shape(); + int mask_dims = static_cast(mask_shape.NumDimensions()); + ORT_ENFORCE(mask_dims == 2 || (mask_dims == 4 && mask_shape[1] == 1), + "nonpad_kv_seqlen + attn_mask composition in unfused path only supports " + "2D masks [q, t] and 4D masks with head_dim=1 [B, 1, q, t]. 
" + "Got mask shape: ", + mask_shape); + + int64_t mask_elements = mask_shape.Size(); + const NativeCudaT* mask_bias_ptr = nullptr; - // Set additional fields - data.bias = nullptr; // New Attention op doesn't have bias - IAllocatorUniquePtr converted_mask_buffer; - if (nullptr != attn_mask) { if (attn_mask->IsDataType()) { - // Convert boolean mask to additive attention bias: true -> 0.0, false -> mask_filter_value. - // The conversion is element-wise and preserves the original shape, so the broadcast flags - // set above apply identically to the converted float buffer. - using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; - int64_t num_elements = attn_mask->Shape().Size(); - converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), context->GetComputeStream()); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + // Convert bool mask to additive bias in a temp buffer, then add in-place. + mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), context->GetComputeStream()); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), - reinterpret_cast(converted_mask_buffer.get()), - num_elements, + reinterpret_cast(mask_bias_buffer.get()), + mask_elements, contribop_parameters.mask_filter_value, cuda_stream, - GetDeviceProp().maxThreadsPerBlock)); - data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); + device_prop.maxThreadsPerBlock)); + mask_bias_ptr = reinterpret_cast(mask_bias_buffer.get()); } else { - data.attention_bias = reinterpret_cast(attn_mask->Data()); + // Float mask is already in additive bias format. + mask_bias_ptr = reinterpret_cast(attn_mask->Data()); } + + // Add mask bias into nonpad bias with cyclic broadcasting. + // 2D mask [q, t]: mask_elements = q*t, repeats for each batch → correct. + // 4D mask [B, 1, q, t]: mask_elements = B*q*t = bias_elements → direct add. 
+ ORT_RETURN_IF_ERROR(LaunchAddBiasInPlace( + reinterpret_cast(converted_mask_buffer.get()), + mask_bias_ptr, + bias_elements, + mask_elements, + cuda_stream, + device_prop.maxThreadsPerBlock)); + } + + data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); + // Composed bias is [B, q_seq, total_seq] → broadcasts over heads but not batch. + contribop_parameters.broadcast_attn_bias_dim_0 = false; + contribop_parameters.broadcast_attn_bias_dim_1 = true; + } else if (attn_mask != nullptr) { + if (attn_mask->IsDataType()) { + using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; + int64_t num_elements = attn_mask->Shape().Size(); + converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), context->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( + attn_mask->Data(), + reinterpret_cast(converted_mask_buffer.get()), + num_elements, + contribop_parameters.mask_filter_value, + cuda_stream, + device_prop.maxThreadsPerBlock)); + data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); + } else { + data.attention_bias = reinterpret_cast(attn_mask->Data()); + } + } + + data.qkv_format = contribop_parameters.qkv_format; + data.use_flash_attention = false; + data.use_memory_efficient_attention = false; + data.fused_runner = nullptr; + data.fused_cross_attention_kernel = nullptr; + data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; + + // Allocate workspace + const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); + size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( + sizeof(T), + contribop_parameters.batch_size, + contribop_parameters.num_heads, + contribop_parameters.head_size, + contribop_parameters.v_head_size, + contribop_parameters.sequence_length, + contribop_parameters.kv_sequence_length, + contribop_parameters.total_sequence_length, + nullptr, false, false, false, false, 
false, + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); + + return onnxruntime::contrib::cuda::QkvToContext( + device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); +} + +// ============================================================================ +// ComputeInternal: Dispatch to appropriate attention kernel +// ============================================================================ +// MHA path (q_num_heads == kv_num_heads): uses direct kernel dispatch cascade +// flash → memory efficient → unfused +// GQA path (q_num_heads != kv_num_heads): uses flash (handles GQA natively) or MEA +// (with head expansion via LaunchUngroup). Unfused fallback not yet supported for GQA. +// ============================================================================ +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* Q = context->Input(0); + const Tensor* K = context->Input(1); + const Tensor* V = context->Input(2); + const Tensor* attn_mask = context->Input(3); + const Tensor* past_key = context->Input(4); + const Tensor* past_value = context->Input(5); + const Tensor* nonpad_kv_seqlen = context->Input(6); // optional, Opset 24 + + // When both nonpad_kv_seqlen and attn_mask are provided, Flash Attention cannot handle + // the combination (no bias parameter). Route to MEA or Unfused which support composition. + if (nonpad_kv_seqlen != nullptr && attn_mask != nullptr) { + LOGS_DEFAULT(VERBOSE) << "Both nonpad_kv_seqlen and attn_mask provided. 
" + << "Flash Attention does not support this combination; " + << "falling back to Memory Efficient Attention or Unfused path."; + } + + attention_helper::AttentionParameters parameters; + TensorShape y_shape; + TensorShape present_key_shape; + TensorShape present_value_shape; + TensorShape output_qk_shape; + + ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( + Q, K, V, attn_mask, past_key, past_value, nonpad_kv_seqlen, + is_causal_, softcap_, softmax_precision_, + qk_matmul_output_mode_, kv_num_heads_, q_num_heads_, scale_, + parameters, y_shape, present_key_shape, present_value_shape, output_qk_shape, + true /* skip_nonpad_data_validation: data is on GPU */) + .IsOK(), + "Output shapes for Attention could not be computed."); + + Tensor* Y = context->Output(0, y_shape); + Tensor* present_key = context->Output(1, present_key_shape); + Tensor* present_value = context->Output(2, present_value_shape); + Tensor* output_qk = context->Output(3, output_qk_shape); + + const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads; + + // === KERNEL SELECTION CASCADE === + // Priority: flash attention > memory efficient attention > unfused attention + // + // 4D BNSH handling per kernel: + // Flash: strictly requires BSNH — Q is transposed BNSH→BSNH before calling mha_fwd*. + // K/V passed as BNSH to mha_fwd_kvcache (it handles both layouts). + // MEA: accepts both BSNH and BNSH natively via is_kv_bsnh flag. Q transposed to BSNH. + // Unfused: accepts both via QkvToContext's qkv_format (Q_K_V_BSNH or Q_K_V_BNSH). + // + // nonpad_kv_seqlen + attn_mask routing: + // Flash: cannot handle this combo (no bias param when seqlens_k is used) → excluded. + // MEA: supports both (custom_right_padding for seqlens + additive attn_bias for mask). + // Unfused: nonpad → attention_bias; mask composed additively when both present. 
+ const bool has_output_qk = (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone); + + // Early-reject features not supported by any CUDA kernel path. + // TODO(titaiwang): Support softcap and softmax_precision on CUDA kernels. + // When a kernel adds support, move these checks to the unfused fallback section. + if (parameters.softcap != 0.0f) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "softcap is not supported yet in Attention op (CUDA)."); + } + if (parameters.softmax_precision != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "softmax_precision is not supported yet in Attention op (CUDA)."); + } + +#if USE_FLASH_ATTENTION + { + auto& device_prop = GetDeviceProp(); + bool flash_eligible = + !disable_flash_attention_ && + !std::is_same::value && + onnxruntime::flash::is_supported(device_prop, parameters.head_size, + parameters.q_num_heads, parameters.kv_num_heads) && + parameters.head_size == parameters.v_head_size && + !has_output_qk && + // Bool masks without past_key (prompt) can't use flash because mha_fwd_kvcache's + // causal semantics are decode-oriented (window offset by seqlens_k). For causal + // prompt with padding, MEA handles it correctly via attention bias conversion. + // Flash handles: no mask, decode with past (±mask), nonpad_kv_seqlen. + // Note: contrib MHA similarly excludes flash when attention_bias is present + // (no mask support in mha_fwd). Float masks and bool prompt masks route to MEA + // which supports additive bias natively. + (attn_mask == nullptr || (attn_mask->IsDataType() && past_key != nullptr)) && + // Flash cannot handle nonpad_kv_seqlen + attn_mask simultaneously (no bias parameter + // in mha_fwd/mha_fwd_kvcache when seqlens_k is used). Route to MEA instead. 
+ !(nonpad_kv_seqlen != nullptr && attn_mask != nullptr); + + if (flash_eligible) { + return RunFlashAttention(context, Q, K, V, attn_mask, past_key, past_value, + nonpad_kv_seqlen, Y, present_key, present_value, parameters); } - data.qkv_format = contribop_parameters.qkv_format; - - // For now, set flags to false and let QkvToContext use the unfused path - data.use_flash_attention = false; - data.use_memory_efficient_attention = false; - data.fused_runner = nullptr; - data.fused_cross_attention_kernel = nullptr; - data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; - - // Allocate workspace for Q, K, V processing and scratch buffer - const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); - size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( - sizeof(T), - contribop_parameters.batch_size, - contribop_parameters.num_heads, - contribop_parameters.head_size, - contribop_parameters.v_head_size, - contribop_parameters.sequence_length, - contribop_parameters.kv_sequence_length, - contribop_parameters.total_sequence_length, - nullptr, // fused_runner - false, // use_flash_attention - false, // use_lean_attention - false, // use_fused_cross_attention - false, // use_memory_efficient_attention - false, // use_cudnn_flash_attention - no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.workspace_bytes = workspace_bytes; - - // Call QkvToContext to perform the attention computation + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + { auto& device_prop = GetDeviceProp(); - cublasHandle_t cublas = GetCublasHandle(context); - cudnnHandle_t cudnn = GetCudnnHandle(context); + int sm = device_prop.major * 10 + device_prop.minor; + bool mea_eligible = + !disable_memory_efficient_attention_ && + 
onnxruntime::contrib::cuda::has_memory_efficient_attention( + sm, std::is_same::value, std::is_same::value, + parameters.head_size, parameters.v_head_size) && + !has_output_qk && + past_key == nullptr; + + // Cutlass FMHA requires bias strides to satisfy minimum alignment even in the + // "unaligned" kernel path. When an attention mask is present (with or without + // nonpad_kv_seqlen), it becomes an additive bias with bias_strideM = + // total_sequence_length. Skip MEA if this stride can't satisfy the kernel's + // minimum alignment requirement. + if (mea_eligible && attn_mask != nullptr) { + int min_bias_align = 1; + if ((std::is_same::value && sm >= 80) || + (!std::is_same::value && sm >= 75)) { + min_bias_align = 4; // TensorOp on Sm80+ (float) or Sm75+ (fp16/bf16) + } else if (!std::is_same::value && sm >= 70) { + min_bias_align = 2; // TensorOp on Volta (fp16) + } + if (parameters.total_sequence_length % min_bias_align != 0) { + mea_eligible = false; + } + } + + if (mea_eligible) { + return RunMemoryEfficientAttention(context, Q, K, V, attn_mask, past_key, past_value, + nonpad_kv_seqlen, Y, present_key, present_value, parameters); + } + } +#endif - // QkvToContext takes two template parameters: T for computation type, QK for output_qk type - // For now, both are the same type (CudaT) + // Fallback: unfused attention + // TODO(titaiwang): Support additional output_qk modes beyond kNone and kQK. + // Currently only unfused handles output_qk, and only kNone/kQK modes. 
+ if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && + qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "qk_matmul_output_mode other than kNone and kQK is not supported yet " + "in Attention op (CUDA)."); + } - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); + if (is_gqa) { + // TODO(titaiwang): Support GQA in unfused attention path for fp32/old-GPU fallback. + // Currently blocked because QkvToContext allocates K/V workspace assuming + // num_heads == kv_num_heads. GQA needs a head expansion step (ExpandKVHeads kernel) + // to replicate kv_num_heads -> q_num_heads before unfused can process. + // Requires ~160 lines. See issue #27516. + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GQA (q_num_heads != kv_num_heads) requires flash or memory efficient attention, " + "but neither is eligible. 
Ensure fp16/bf16 on Ampere+ GPU, or check head_size constraints."); } + + return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, + nonpad_kv_seqlen, Y, present_key, present_value, output_qk, parameters); } + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h index 951b3b2e2f3c1..3f69c7a77f497 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.h +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -14,6 +14,42 @@ class Attention final : public CudaKernel { Attention(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; + private: + Status RunFlashAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + const attention_helper::AttentionParameters& parameters) const; + + Status RunMemoryEfficientAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + const attention_helper::AttentionParameters& parameters) const; + + Status RunUnfusedAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, + const attention_helper::AttentionParameters& parameters) const; + + Status ConvertAttnMaskToBias( + OpKernelContext* context, + const Tensor* attn_mask, + cudaStream_t cuda_stream, + int max_threads_per_block, + IAllocatorUniquePtr& converted_mask_buffer, + const void*& 
attn_bias_data, + bool& broadcast_bias_dim_0, + bool& broadcast_bias_dim_1) const; + protected: bool is_causal_; int kv_num_heads_; @@ -22,6 +58,8 @@ class Attention final : public CudaKernel { float scale_; float softcap_; int softmax_precision_; + bool disable_flash_attention_; + bool disable_memory_efficient_attention_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu index 08f93d48ebcaa..a1febc8a26a44 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu @@ -15,11 +15,11 @@ namespace cuda { // where padding starts. The sequence length is the index of first False. // // Validation (via CUDA_KERNEL_ASSERT, reported asynchronously): -// - The mask must start with True (first element must be True) +// - All-false masks are valid (represents fully masked / zero-length sequence) // - After the first False, all remaining elements must be False (contiguous padding) // // Handle broadcasting: -// - 2D mask (batch_size, total_seq_len): stride = total_seq_len, batch_idx = threadIdx +// - 2D mask (q_seq_len, total_seq_len): broadcasts over batch; uses first query position (row 0) // - 3D mask (num_heads, q_seq_len, total_seq_len): broadcasts to [1, num_heads, q_seq, total_seq] // No per-batch variation; uses first head, first q position for all batches // - 4D mask (B, H, q_seq_len, total_seq_len): we look at first head, first q position @@ -31,7 +31,8 @@ __global__ void ConvertMaskToSeqlensKernel( const int mask_dims, const int64_t mask_dim0, const int64_t mask_dim1, - const int64_t mask_dim2) { + const int64_t mask_dim2, + const int seqlen_offset) { int batch_idx = threadIdx.x + blockIdx.x * blockDim.x; if (batch_idx >= batch_size) { return; @@ -42,10 +43,11 @@ __global__ void ConvertMaskToSeqlensKernel( const bool* mask_row = nullptr; if (mask_dims == 2) { - // Shape: (batch_size or 
1, total_seq_len) - // If mask_dim0 == 1, broadcast across all batches - int effective_batch = (mask_dim0 == 1) ? 0 : batch_idx; - mask_row = attn_mask + effective_batch * total_seq_len; + // Shape: (q_seq_len, total_seq_len) per ONNX spec. Broadcasts over batch. + // Use first query position (row 0) for sequence length determination. + // For 2D masks [q_seq, total_seq], only used in decode path where q_seq=1, + // so row 0 is always correct. Flash excludes 2D bool masks for prompt. + mask_row = attn_mask; } else if (mask_dims == 3) { // Shape: (num_heads, q_seq_len, total_seq_len) // This broadcasts to [1, num_heads, q_seq, total_seq] - same mask for all batches @@ -72,32 +74,41 @@ __global__ void ConvertMaskToSeqlensKernel( mask_row = attn_mask + effective_batch * batch_stride + h_idx * head_stride + q_idx * q_stride; } - // Validate that mask starts with True (right-padding convention) - CUDA_KERNEL_ASSERT(mask_row[0]); // mask must start with True - // Find the first False (where padding starts) // All elements before this should be True, all after should be False - int seq_len = total_seq_len; // Default: all True (no padding) - bool found_first_false = false; - - for (int i = 1; i < total_seq_len; ++i) { - bool current = mask_row[i]; - - if (!found_first_false && !current) { - // Found first False - this is where padding starts - seq_len = i; - found_first_false = true; - } else if (found_first_false && current) { - // Found True after False - mask is not contiguous (invalid) - CUDA_KERNEL_ASSERT(false); // mask must be contiguous (no True after False) + int seq_len; + if (!mask_row[0]) { + // Entire row is padding (all-false mask) + seq_len = 0; + } else { + seq_len = total_seq_len; // Default: all True (no padding) + bool found_first_false = false; + + for (int i = 1; i < total_seq_len; ++i) { + bool current = mask_row[i]; + + if (!found_first_false && !current) { + // Found first False - this is where padding starts + seq_len = i; + found_first_false = 
true; + } else if (found_first_false && current) { + // Found True after False - mask is not contiguous (invalid) + CUDA_KERNEL_ASSERT(false); // mask must be contiguous (no True after False) + } } } - // seqlens_k is total_sequence_length - 1 for GQA convention - seqlens_k[batch_idx] = seq_len - 1; + // seqlens_k output: seq_len + seqlen_offset + // Decode with past (seqlen_offset=-kv_seq_len): pre-append cache count + // Prompt/MEA (seqlen_offset=0): actual token count + // Clamp to 0: all-false mask (seq_len=0) with negative decode offset + // would produce negative seqlens_k, which is undefined in Flash kernels. + seqlens_k[batch_idx] = max(0, seq_len + seqlen_offset); } -Status LaunchConvertMaskToSeqlensK( +// Convert boolean mask to sequence lengths with a configurable offset. +// seqlens_k[b] = num_true_tokens + seqlen_offset +Status LaunchConvertMaskToFlashSeqlensK( const bool* attn_mask_bool, int* seqlens_k, int batch_size, @@ -107,7 +118,8 @@ Status LaunchConvertMaskToSeqlensK( int64_t mask_dim1, int64_t mask_dim2, cudaStream_t stream, - int max_threads_per_block) { + int max_threads_per_block, + int seqlen_offset) { if (batch_size == 0 || total_seq_len == 0) { return Status::OK(); } @@ -123,7 +135,8 @@ Status LaunchConvertMaskToSeqlensK( mask_dims, mask_dim0, mask_dim1, - mask_dim2); + mask_dim2, + seqlen_offset); return CUDA_CALL(cudaGetLastError()); } @@ -173,5 +186,165 @@ template Status LaunchConvertBoolMaskToAttentionBias<__half>( template Status LaunchConvertBoolMaskToAttentionBias<__nv_bfloat16>( const bool*, __nv_bfloat16*, int64_t, float, cudaStream_t, int); +// Convert nonpad_kv_seqlen (int64) to seqlens_k (int32) as actual token count. +// Flash attention's mha_fwd_kvcache expects seqlens_k_ = number of valid tokens. 
+__global__ void ConvertNonpadKvSeqlenToFlashSeqlensKKernel( + const int64_t* __restrict__ nonpad_kv_seqlen, + int* __restrict__ seqlens_k, + const int batch_size, + const int total_sequence_length) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < batch_size) { + int64_t val = nonpad_kv_seqlen[idx]; + CUDA_KERNEL_ASSERT(val >= 0); + CUDA_KERNEL_ASSERT(val <= static_cast(total_sequence_length)); + val = max(static_cast(0), min(val, static_cast(total_sequence_length))); + seqlens_k[idx] = static_cast(val); // count, not index + } +} + +Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( + const int64_t* nonpad_kv_seqlen, + int* seqlens_k, + int batch_size, + int total_sequence_length, + cudaStream_t stream, + int max_threads_per_block) { + if (batch_size == 0) { + return Status::OK(); + } + + int threads = std::min(batch_size, max_threads_per_block); + int blocks = (batch_size + threads - 1) / threads; + + ConvertNonpadKvSeqlenToFlashSeqlensKKernel<<>>( + nonpad_kv_seqlen, seqlens_k, batch_size, total_sequence_length); + + return CUDA_CALL(cudaGetLastError()); +} + +// CUDA kernel to convert nonpad_kv_seqlen to an additive attention bias. 
+// Generates (batch_size, q_seq_len, total_seq_len) output where: +// position t < nonpad_kv_seqlen[b] → 0.0 (attend) +// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out) +template +__global__ void ConvertNonpadKvSeqlenToAttentionBiasKernel( + const int64_t* __restrict__ nonpad_kv_seqlen, + T* __restrict__ attention_bias, + const int batch_size, + const int q_seq_len, + const int total_seq_len, + const float mask_filter_value) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t total = static_cast(batch_size) * q_seq_len * total_seq_len; + for (; idx < total; idx += static_cast(gridDim.x) * blockDim.x) { + int b = static_cast(idx / (static_cast(q_seq_len) * total_seq_len)); + int t = static_cast(idx % total_seq_len); + int64_t valid_len = nonpad_kv_seqlen[b]; + CUDA_KERNEL_ASSERT(valid_len >= 0 && valid_len <= static_cast(total_seq_len)); + valid_len = max(static_cast(0), min(valid_len, static_cast(total_seq_len))); + attention_bias[idx] = (t < static_cast(valid_len)) ? 
T(0.0f) : T(mask_filter_value); + } +} + +template +Status LaunchConvertNonpadKvSeqlenToAttentionBias( + const int64_t* nonpad_kv_seqlen, + T* attention_bias, + int batch_size, + int q_seq_len, + int total_seq_len, + float mask_filter_value, + cudaStream_t stream, + int max_threads_per_block) { + int64_t total = static_cast(batch_size) * q_seq_len * total_seq_len; + if (total == 0) { + return Status::OK(); + } + + int threads = static_cast(std::min(static_cast(max_threads_per_block), total)); + int64_t blocks = (total + threads - 1) / threads; + constexpr int64_t kMaxGridDimX = 65535; + unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); + + ConvertNonpadKvSeqlenToAttentionBiasKernel<<>>( + nonpad_kv_seqlen, attention_bias, batch_size, q_seq_len, total_seq_len, mask_filter_value); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchConvertNonpadKvSeqlenToAttentionBias( + const int64_t*, float*, int, int, int, float, cudaStream_t, int); +template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__half>( + const int64_t*, __half*, int, int, int, float, cudaStream_t, int); +template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__nv_bfloat16>( + const int64_t*, __nv_bfloat16*, int, int, int, float, cudaStream_t, int); + +// Add an addend bias into an existing bias buffer using cyclic broadcasting. +// Used to compose nonpad_kv_seqlen bias [B, q, t] with an attn_mask bias that +// may be smaller (e.g. 2D [q, t] broadcasts over batch). 
+template +__global__ void AddBiasInPlaceKernel( + T* __restrict__ bias, + const T* __restrict__ addend, + int64_t total_elements, + int64_t addend_elements) { + for (int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + idx < total_elements; + idx += static_cast(gridDim.x) * blockDim.x) { + float sum = static_cast(bias[idx]) + static_cast(addend[idx % addend_elements]); + bias[idx] = T(sum); + } +} + +template +Status LaunchAddBiasInPlace( + T* bias, + const T* addend, + int64_t total_elements, + int64_t addend_elements, + cudaStream_t stream, + int max_threads_per_block) { + if (total_elements == 0 || addend_elements == 0) { + return Status::OK(); + } + + int threads = static_cast(std::min(static_cast(max_threads_per_block), total_elements)); + int64_t blocks = (total_elements + threads - 1) / threads; + constexpr int64_t kMaxGridDimX = 65535; + unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); + + AddBiasInPlaceKernel<<>>( + bias, addend, total_elements, addend_elements); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchAddBiasInPlace(float*, const float*, int64_t, int64_t, cudaStream_t, int); +template Status LaunchAddBiasInPlace<__half>(__half*, const __half*, int64_t, int64_t, cudaStream_t, int); +template Status LaunchAddBiasInPlace<__nv_bfloat16>(__nv_bfloat16*, const __nv_bfloat16*, int64_t, int64_t, cudaStream_t, int); + +// Simple kernel to fill an int32 buffer with a constant value on device. +// Used for CUDA-graph-capturable seqlens_k initialization (no host memory). 
+__global__ void FillInt32Kernel(int* __restrict__ output, const int value, const int count) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < count) { + output[idx] = value; + } +} + +Status LaunchFillInt32(int* output, int value, int count, cudaStream_t stream, int max_threads_per_block) { + if (count == 0) { + return Status::OK(); + } + + int threads = std::min(count, max_threads_per_block); + int blocks = (count + threads - 1) / threads; + + FillInt32Kernel<<>>(output, value, count); + + return CUDA_CALL(cudaGetLastError()); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h index e2f5383de05f1..48b0894ad5697 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h @@ -8,37 +8,30 @@ namespace onnxruntime { namespace cuda { -// Convert a boolean attention mask to sequence lengths for use with GQA kernels. +// Convert a boolean attention mask to sequence lengths with a configurable offset. // // The mask is expected to have the following properties: // 1. It represents right-padding only (valid tokens first, padding at the end) -// 2. Each batch's mask should start with True (valid) values +// 2. All-false masks (zero-length sequence) are valid; otherwise mask should start with True // 3. True values should be contiguous, followed by contiguous False (padding) values // 4. 
The mask must be broadcastable to (batch_size, num_heads, q_seq_len, total_seq_len) // -// For 2D mask (batch_size, total_seq_len): uses the mask directly per batch +// For 2D mask (q_seq_len, total_seq_len): broadcasts over batch; uses first query position (row 0) // For 3D mask (num_heads, q_seq_len, total_seq_len): broadcasts across batches, uses first head/q // For 4D mask (B, H, q_seq_len, total_seq_len): uses first head, first q position // -// Parameters: -// attn_mask_bool: Input boolean mask on GPU (True = valid, False = padding) -// seqlens_k: Output buffer for sequence lengths (seqlen - 1 for GQA convention) -// batch_size: Number of batches -// total_seq_len: Total sequence length (last dimension of mask) -// mask_dims: Number of dimensions in the mask (2, 3, or 4) -// mask_dim0: First dimension of mask (batch_size for 2D, num_heads for 3D, batch_size for 4D) -// mask_dim1: Second dimension (0 for 2D, q_seq_len for 3D, num_heads for 4D) -// mask_dim2: Third dimension (0 for 2D/3D, q_seq_len for 4D) -// stream: CUDA stream -// max_threads_per_block: Maximum threads per block +// seqlen_offset adjusts the raw token count: +// seqlens_k[b] = num_true_tokens + seqlen_offset // -// Returns: -// Status::OK() on success +// Common offsets: +// 0: actual token count (for prompt with mha_fwd_kvcache, MEA custom right padding) +// -N: subtract N from count (for decode with mha_fwd_kvcache where N=kv_sequence_length, +// giving the number of tokens already in cache BEFORE appending new ones) // -// Note: Mask validity (right-padding convention, starts with True, contiguous True/False) +// Note: Mask validity (right-padding convention, contiguous True/False) // is checked asynchronously via CUDA_KERNEL_ASSERT inside the kernel. Invalid masks will // trigger a device-side assertion failure. 
-Status LaunchConvertMaskToSeqlensK( +Status LaunchConvertMaskToFlashSeqlensK( const bool* attn_mask_bool, int* seqlens_k, int batch_size, @@ -48,7 +41,8 @@ Status LaunchConvertMaskToSeqlensK( int64_t mask_dim1, int64_t mask_dim2, cudaStream_t stream, - int max_threads_per_block); + int max_threads_per_block, + int seqlen_offset = 0); // Convert a boolean attention mask to an additive attention bias for the MHA path. // Maps true -> 0.0 (attend) and false -> mask_filter_value (mask out). @@ -62,5 +56,48 @@ Status LaunchConvertBoolMaskToAttentionBias( cudaStream_t stream, int max_threads_per_block); +// Convert nonpad_kv_seqlen (int64, per-batch valid KV lengths) to seqlens_k (int32) +// as actual token count. Flash attention's mha_fwd_kvcache expects seqlens_k_ = number +// of valid tokens. +Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( + const int64_t* nonpad_kv_seqlen, + int* seqlens_k, + int batch_size, + int total_sequence_length, + cudaStream_t stream, + int max_threads_per_block); + +// Convert nonpad_kv_seqlen to an additive attention bias for the MHA unfused path. +// Generates a (batch_size, q_seq_len, total_seq_len) tensor where: +// position t < nonpad_kv_seqlen[b] → 0.0 (attend) +// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out) +template +Status LaunchConvertNonpadKvSeqlenToAttentionBias( + const int64_t* nonpad_kv_seqlen, + T* attention_bias, + int batch_size, + int q_seq_len, + int total_seq_len, + float mask_filter_value, + cudaStream_t stream, + int max_threads_per_block); + +// Additively compose an addend bias into an existing bias buffer in-place. +// Supports cyclic broadcasting: addend of size [q, t] is repeated over batch +// to compose with a bias of size [B, q, t]. When both have the same number +// of elements (e.g. 4D mask [B, 1, q, t]), it performs a direct element-wise add. 
+template +Status LaunchAddBiasInPlace( + T* bias, + const T* addend, + int64_t total_elements, + int64_t addend_elements, + cudaStream_t stream, + int max_threads_per_block); + +// Fill an int32 buffer with a constant value entirely on device. +// CUDA-graph-capturable alternative to host vector + cudaMemcpyAsync. +Status LaunchFillInt32(int* output, int value, int count, cudaStream_t stream, int max_threads_per_block); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index b0c6c6d801c4b..6c4d09c9b96e9 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -481,6 +481,95 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) { ); } +// Regression test: all-false bool mask in decode mode (past_sequence_length > 0). +// Before the fix: seq_len=0 + negative seqlen_offset produced negative seqlens_k → UB/crash. +// After the fix: clamped to max(0, ...) → uniform softmax → mean of V values. 
+TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { + int batch_size = 1; + int q_num_heads = 2; + int q_sequence_length = 1; // decode: single token + int head_size = 8; + int kv_sequence_length = 1; // appending 1 new token + int kv_num_heads = 2; + int v_head_size = 8; + int past_sequence_length = 3; // 3 tokens in cache → total_seq = 4 + + // Q: [1, 2, 1, 8] — values don't matter for uniform softmax test + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.5f); + + // K: [1, 2, 1, 8] + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.5f); + + // V: [1, 2, 1, 8] — new token values (will be concatenated with past_value) + std::vector v = { + // head 0 + 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, + // head 1 + 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f}; + + // past_key: [1, 2, 3, 8] + std::vector past_key(batch_size * kv_num_heads * past_sequence_length * head_size, 0.5f); + + // past_value: [1, 2, 3, 8] — distinct per-row values so mean is meaningful + std::vector past_value = { + // head 0: 3 past positions + 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, + 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, + 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, + // head 1: 3 past positions + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, + 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, 0.7f}; + + // Bool mask: [1, 4] — all false (every position masked) + std::initializer_list m = {false, false, false, false}; + + // present_key = concat(past_key, k) along seq dim → [1, 2, 4, 8] + // past_key is all 0.5, new k is all 0.5 → present_key is all 0.5 + int total_sequence_length = past_sequence_length + kv_sequence_length; + std::vector present_key(batch_size * kv_num_heads * total_sequence_length * head_size, 0.5f); + + // present_value = concat(past_value, v) along seq dim → [1, 2, 4, 8] + std::vector present_value = { + // head 0: past rows (0.1, 
0.2, 0.3) + new row (0.4) + 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, + 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, + 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, + 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, 0.4f, + // head 1: past rows (0.5, 0.6, 0.7) + new row (0.8) + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, 0.6f, + 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, 0.7f, + 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f}; + + // With all-false mask, softmax produces uniform weights: 1/4 per position. + // Output = mean of V rows (past_value concat new_v): + // head 0: mean(0.1, 0.2, 0.3, 0.4) = 0.25 + // head 1: mean(0.5, 0.6, 0.7, 0.8) = 0.65 + std::vector y = { + // head 0 + 0.25f, 0.25f, 0.25f, 0.25f, 0.25f, 0.25f, 0.25f, 0.25f, + // head 1 + 0.65f, 0.65f, 0.65f, 0.65f, 0.65f, 0.65f, 0.65f, 0.65f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * total_sequence_length); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * total_sequence_length * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * total_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), m, past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, + y, present_key, present_value, std::vector(), + 
false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + TEST(AttentionTest, Attention4DDefaultFloat16) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -1532,10 +1621,141 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_AllMasked) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); + if (HasCudaEnvironment(0)) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Edge case: nonpad_kv_seqlen=0 exercised on CUDA Flash/MEA path with fp16 and GQA. +// This verifies the val >= 0 assertion in ConvertNonpadKvSeqlenToFlashSeqlensKKernel +// and ConvertNonpadKvSeqlenToAttentionBiasKernel. +TEST(AttentionTest, Attention_NonPadKVSeqLen_AllMasked_FP16_GQA) { + if (!HasCudaEnvironment(530)) { + return; // fp16 requires SM 5.3+ + } + + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + // batch=2: batch 0 is fully masked (nonpad=0), batch 1 has 2 valid positions (nonpad=2). + // GQA: q_num_heads=4, kv_num_heads=2 → triggers Flash/MEA GQA path. + // 4D BNSH: [B, N, S, H] + int batch_size = 2; + int q_num_heads = 4; + int kv_num_heads = 2; + int q_sequence_length = 1; + int kv_sequence_length = 4; + int head_size = 64; + + int q_elements = batch_size * q_num_heads * q_sequence_length * head_size; + int k_elements = batch_size * kv_num_heads * kv_sequence_length * head_size; + int v_elements = k_elements; + + // Use constant values for predictable uniform-softmax results. + std::vector q(q_elements, 1.0f); + std::vector k(k_elements, 1.0f); + // V: each row = row_index * 0.1 for distinct values across positions. 
+ std::vector v(v_elements); + for (int b = 0; b < batch_size; b++) { + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < kv_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < head_size; h++) { + v[(b * kv_num_heads * kv_sequence_length + n * kv_sequence_length + s) * head_size + h] = val; + } + } + } + } + + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, ToFloat16(q)); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(k)); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(v)); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + // batch 0: nonpad=0 (all masked), batch 1: nonpad=2 (first 2 of 4 positions valid) + test.AddInput("nonpad_kv_seqlen", {batch_size}, {0, 2}); + + // Expected output shape: [B, q_num_heads, q_seq, head_size] + // Batch 0 (all masked, nonpad=0): Flash/MEA returns zeros for fully-masked sequences. 
+ // Batch 1 (2 valid): softmax over positions 0..1 (equal K, so uniform) + // mean of first 2 rows = (0.1 + 0.2) / 2 = 0.15 for each head dim + int y_elements = batch_size * q_num_heads * q_sequence_length * head_size; + std::vector expected_y(y_elements, 0.0f); // batch 0 = zeros + for (int n = 0; n < q_num_heads; n++) { + for (int h = 0; h < head_size; h++) { + expected_y[(1 * q_num_heads + n) * head_size + h] = 0.15f; // batch 1 + } + } + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, + ToFloat16(expected_y), false, 0, 0.02f); + test.AddOptionalOutputEdge(); + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Regression test: BFloat16 must route to Flash Attention on SM80+. +// BFloat16 is a 2-byte type so disable_flash_attention_ should be false, +// allowing Flash to handle GQA natively via kv_num_heads. 
+TEST(AttentionTest, Attention_NonPadKVSeqLen_BF16_Flash) { + if (!HasCudaEnvironment(800)) { + return; // BFloat16 requires SM 8.0+ + } + + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + int batch_size = 1; + int q_num_heads = 4; + int kv_num_heads = 2; + int q_sequence_length = 1; + int kv_sequence_length = 4; + int head_size = 64; + + int q_elements = batch_size * q_num_heads * q_sequence_length * head_size; + int k_elements = batch_size * kv_num_heads * kv_sequence_length * head_size; + + std::vector q(q_elements, 1.0f); + std::vector k(k_elements, 1.0f); + std::vector v(k_elements); + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < kv_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < head_size; h++) { + v[(n * kv_sequence_length + s) * head_size + h] = val; + } + } + } + + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, FloatsToBFloat16s(q)); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, FloatsToBFloat16s(k)); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, FloatsToBFloat16s(v)); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("nonpad_kv_seqlen", {batch_size}, {2}); + + // 2 valid positions: uniform softmax → mean of V[0] and V[1] = (0.1 + 0.2) / 2 = 0.15 + int y_elements = batch_size * q_num_heads * q_sequence_length * head_size; + std::vector expected_y(y_elements, 0.15f); + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, + FloatsToBFloat16s(expected_y), false, 0, 0.02f); + test.AddOptionalOutputEdge(); + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, 
&execution_providers); } -// Edge case: nonpad_kv_seqlen = total_sequence_length (no positions masked). TEST(AttentionTest, Attention_NonPadKVSeqLen_NoneMasked) { OpTester test("Attention", 24, onnxruntime::kOnnxDomain); @@ -1620,5 +1840,102 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_ExceedsTotalSeqLen) { {}, nullptr, &execution_providers); } +// Test combined nonpad_kv_seqlen + bool attn_mask. +// Both masks should compose additively: nonpad_kv_seqlen masks positions >= valid_len, +// and attn_mask further masks positions within the valid range. +// Previously this combination crashed with ORT_ENFORCE; now it falls back to MEA gracefully. +TEST(AttentionTest, Attention_NonPadKVSeqLen_WithBoolAttnMask) { + // batch_size=1, q_num_heads=1, kv_num_heads=1 + // q_seq_len=1, kv_seq_len=4, head_size=2 + // nonpad_kv_seqlen=[3] => positions 0,1,2 valid by padding + // attn_mask=[true, false, true, true] => position 1 masked by attn_mask + // Combined: only positions 0 and 2 are effective + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {1, 1, 1, 2}; + std::vector k_shape = {1, 1, 4, 2}; + std::vector v_shape = {1, 1, 4, 2}; + + // Q and K designed for uniform attention scores on valid positions + std::vector q = {1.0f, 1.0f}; + std::vector k(8, 1.0f); + // V: [10, 20], [30, 40], [50, 60], [70, 80] + std::vector v = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f}; + + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + // 2D bool mask [1, 4]: position 1 is false (masked out) + test.AddInput("attn_mask", {1, 4}, {true, false, true, true}); + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {1}, {3}); + + // nonpad masks position 3 (>= valid_len=3), attn_mask masks position 1 + // Active positions: 0 and 2, with uniform attention scores + // Output = mean(V[0], V[2]) = [(10+50)/2, (20+60)/2] = [30.0, 
40.0] + std::vector expected_y = {30.0f, 40.0f}; + test.AddOutput("Y", {1, 1, 1, 2}, expected_y, false, 0, 1e-3f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + if (HasCudaEnvironment(0)) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test combined nonpad_kv_seqlen + float attn_mask with multi-batch. +// Verifies additive composition works with float attention bias values. +TEST(AttentionTest, Attention_NonPadKVSeqLen_WithFloatAttnMask_MultiBatch) { + // batch_size=2, q_num_heads=1, kv_num_heads=1 + // q_seq_len=1, kv_seq_len=4, head_size=2 + // nonpad_kv_seqlen=[3, 2] => batch 0 has 3 valid, batch 1 has 2 valid + // 2D float attn_mask [1, 4] = [0, -inf, 0, 0] => masks position 1 for all batches + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {2, 1, 1, 2}; + std::vector k_shape = {2, 1, 4, 2}; + std::vector v_shape = {2, 1, 4, 2}; + + std::vector q = {1.0f, 1.0f, 1.0f, 1.0f}; + std::vector k(16, 1.0f); + // Batch 0 V: [1,1], [2,2], [3,3], [99,99] + // Batch 1 V: [5,5], [6,6], [99,99], [99,99] + std::vector v = { + 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 99.0f, 99.0f, // batch 0 + 5.0f, 5.0f, 6.0f, 6.0f, 99.0f, 99.0f, 99.0f, 99.0f // batch 1 + }; + + float neg_inf = -std::numeric_limits::infinity(); + + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + // 2D float mask [1, 4]: position 1 gets -inf + test.AddInput("attn_mask", {1, 4}, {0.0f, neg_inf, 0.0f, 0.0f}); + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {2}, {3, 2}); + + // Batch 0: nonpad masks pos 3, attn_mask masks pos 1 → active: 0,2 + // Output = mean(V[0], V[2]) = 
[(1+3)/2, (1+3)/2] = [2.0, 2.0] + // Batch 1: nonpad masks pos 2,3, attn_mask masks pos 1 → active: 0 only + // Output = V[0] = [5.0, 5.0] + std::vector expected_y = {2.0f, 2.0f, 5.0f, 5.0f}; + test.AddOutput("Y", {2, 1, 1, 2}, expected_y, false, 0, 1e-3f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + if (HasCudaEnvironment(0)) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py index 10a38329549a8..48640fa38aca2 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py @@ -95,6 +95,10 @@ class AttentionConfig: has_attn_mask: bool = False attn_mask_dims: int = 2 # 2D, 3D, or 4D boolean mask attn_mask_type: str = "bool" # "bool" for GQA path, "additive" for MHA path + has_nonpad_kv_seqlen: bool = False # Opset 24: nonpad_kv_seqlen input + use_4d_bnsh: bool = False # Use 4D [B, num_heads, seq, head_size] inputs instead of 3D + broadcast_mask_batch: bool = False # Use batch dim 1 in 4D mask for broadcasting + broadcast_mask_heads: bool = False # Use heads dim 1 in 4D mask for broadcasting # ################################################################################################# @@ -157,6 +161,7 @@ def create_attention_node_and_io( "attn_mask" if config.has_attn_mask else "", "past_key" if is_past else "", "past_value" if is_past else "", + "nonpad_kv_seqlen" if config.has_nonpad_kv_seqlen else "", ] # Remove trailing empty strings @@ -178,17 +183,40 @@ def create_attention_node_and_io( ) # --- Graph Inputs 
--- - # ONNX Attention op uses 3D inputs: [batch, seq_len, hidden_size] - q_hidden_size = config.q_num_heads * config.head_size - kv_hidden_size = config.kv_num_heads * config.head_size - - graph_input = [ - helper.make_tensor_value_info("query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]), - helper.make_tensor_value_info("key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size]), - helper.make_tensor_value_info( - "value", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] - ), - ] + if config.use_4d_bnsh: + # 4D BNSH inputs: [batch, num_heads, seq_len, head_size] + graph_input = [ + helper.make_tensor_value_info( + "query", + ort_type, + [config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size], + ), + helper.make_tensor_value_info( + "key", + ort_type, + [config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size], + ), + helper.make_tensor_value_info( + "value", + ort_type, + [config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size], + ), + ] + else: + # 3D inputs: [batch, seq_len, hidden_size] + q_hidden_size = config.q_num_heads * config.head_size + kv_hidden_size = config.kv_num_heads * config.head_size + graph_input = [ + helper.make_tensor_value_info( + "query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size] + ), + helper.make_tensor_value_info( + "key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + ), + helper.make_tensor_value_info( + "value", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + ), + ] if isinstance(config.kv_cache_type, torch.dtype): cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] @@ -206,18 +234,19 @@ def create_attention_node_and_io( mask_ort_type = ort_type # additive mask uses same type as Q/K/V # Mask shapes differ between GQA (bool) and MHA (additive/bool) paths: - # GQA bool: 2D=[batch, 
total_seq] — GQA converts to seqlens_k directly, bypassing ONNX broadcasting. - # MHA (additive or bool): 2D=[q_seq, total_seq] — follows ONNX right-aligned broadcasting. + # Per ONNX spec, 2D mask is [q_seq, total_seq] for all paths. # 3D and 4D are the same for both paths. # ONNX broadcasting aligns from the right: 3D [A, B, C] → [_, A, B, C] where A=heads is_gqa = config.kv_num_heads != config.q_num_heads if config.attn_mask_type == "bool" and is_gqa: if config.attn_mask_dims == 2: - mask_shape = [config.batch_size, mask_seq_len] + mask_shape = [config.q_sequence_length, mask_seq_len] elif config.attn_mask_dims == 3: mask_shape = [config.q_num_heads, config.q_sequence_length, mask_seq_len] else: # 4D - mask_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, mask_seq_len] + mask_batch = 1 if config.broadcast_mask_batch else config.batch_size + mask_heads = 1 if config.broadcast_mask_heads else config.q_num_heads + mask_shape = [mask_batch, mask_heads, config.q_sequence_length, mask_seq_len] else: # additive, or bool on MHA path if config.attn_mask_dims == 2: mask_shape = [config.q_sequence_length, mask_seq_len] @@ -225,7 +254,9 @@ def create_attention_node_and_io( # 3D aligns to [_, heads, q_seq, total_seq] — dim 0 must be 1 or num_heads mask_shape = [config.q_num_heads, config.q_sequence_length, mask_seq_len] else: # 4D - mask_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, mask_seq_len] + mask_batch = 1 if config.broadcast_mask_batch else config.batch_size + mask_heads = 1 if config.broadcast_mask_heads else config.q_num_heads + mask_shape = [mask_batch, mask_heads, config.q_sequence_length, mask_seq_len] graph_input.append(helper.make_tensor_value_info("attn_mask", mask_ort_type, mask_shape)) # past_key and past_value for ONNX Attention op @@ -239,13 +270,20 @@ def create_attention_node_and_io( ] ) + # nonpad_kv_seqlen for Opset 24: int64 tensor [batch_size] + if config.has_nonpad_kv_seqlen: + 
graph_input.append(helper.make_tensor_value_info("nonpad_kv_seqlen", TensorProto.INT64, [config.batch_size])) + # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + if config.use_4d_bnsh: + output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size] + else: + output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] + graph_output = [ - helper.make_tensor_value_info( - "output", ort_type, [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] - ), + helper.make_tensor_value_info("output", ort_type, output_shape), helper.make_tensor_value_info("present_key", cache_ort_type, output_k_shape), helper.make_tensor_value_info("present_value", cache_ort_type, output_k_shape), ] @@ -262,11 +300,17 @@ def create_attention_node_and_io( return node, graph_input, graph_output +def _get_opset_version(config: AttentionConfig): + """Return 24 when nonpad_kv_seqlen is used, otherwise 23.""" + return 24 if config.has_nonpad_kv_seqlen else 23 + + def create_attention_graph_prompt(config: AttentionConfig, ort_type): """Create ONNX graph for prompt phase (no past KV cache).""" node, graph_input, graph_output = create_attention_node_and_io(config, ort_type, is_past=False) graph = helper.make_graph([node], "Attention_Graph", graph_input, graph_output) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)]) + opset = _get_opset_version(config) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) return model.SerializeToString() @@ -274,7 +318,8 @@ def create_attention_graph_past(config: AttentionConfig, ort_type): """Create ONNX graph for decoding phase (with past KV cache).""" node, graph_input, graph_output = create_attention_node_and_io(config, ort_type, is_past=True) graph = helper.make_graph([node], "Attention_Graph", graph_input, graph_output) - model = 
helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)]) + opset = _get_opset_version(config) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) return model.SerializeToString() @@ -338,6 +383,7 @@ def attention_prompt_func( ep, device, ort_type=TensorProto.FLOAT16, + nonpad_kv_seqlen=None, ): """ Run ONNX Attention op for prompt phase (no past KV cache). @@ -351,6 +397,7 @@ def attention_prompt_func( ep: Execution provider (e.g., "CUDAExecutionProvider") device: Device string (e.g., "cuda") ort_type: ONNX tensor type + nonpad_kv_seqlen: Optional int64 tensor [batch_size] for opset 24 """ if not config.kv_cache_type: config.kv_cache_type = { @@ -364,30 +411,55 @@ def attention_prompt_func( ort_type=ort_type, ) - # Reshape to 3D [batch, seq_len, hidden_size] - q_3d = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) - k_3d = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) - v_3d = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + # Reshape inputs for ONNX graph + if config.use_4d_bnsh: + # 4D BNSH: [batch, num_heads, seq_len, head_size] + q_input = q.transpose(1, 2).contiguous() # BSNH → BNSH + k_input = k.transpose(1, 2).contiguous() + v_input = v.transpose(1, 2).contiguous() + else: + # 3D: [batch, seq_len, hidden_size] + q_input = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + k_input = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v_input = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) io_binding = ort_session.io_binding() # Bind inputs - bind_tensor(io_binding, "query", q_3d, device, ort_type) - bind_tensor(io_binding, "key", k_3d, device, ort_type) - bind_tensor(io_binding, "value", v_3d, device, ort_type) + bind_tensor(io_binding, "query", q_input, device, ort_type) + 
bind_tensor(io_binding, "key", k_input, device, ort_type) + bind_tensor(io_binding, "value", v_input, device, ort_type) # Bind optional attention mask if config.has_attn_mask and attn_mask is not None: mask_ort_type = _get_mask_ort_type(config, ort_type) bind_tensor(io_binding, "attn_mask", attn_mask, device, mask_ort_type) + # Bind optional nonpad_kv_seqlen (opset 24+) + if config.has_nonpad_kv_seqlen: + if nonpad_kv_seqlen is None: + raise ValueError( + "Invariant violated: config.has_nonpad_kv_seqlen=True but the nonpad_kv_seqlen " + "tensor is None. Either provide the tensor or set has_nonpad_kv_seqlen=False." + ) + bind_tensor(io_binding, "nonpad_kv_seqlen", nonpad_kv_seqlen, device, TensorProto.INT64) + # Bind Outputs hidden_size = config.q_num_heads * config.head_size out_dtype = _get_out_dtype(ort_type) - out_torch = torch.zeros((config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device) + if config.use_4d_bnsh: + out_torch = torch.zeros( + (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + dtype=out_dtype, + device=device, + ) + else: + out_torch = torch.zeros( + (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape for prompt (no past) @@ -450,10 +522,17 @@ def attention_past_func( ort_type=ort_type, ) - # Reshape to 3D [batch, seq_len, hidden_size] - q_3d = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) - new_k_3d = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v_3d = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + # Reshape inputs for ONNX graph + if config.use_4d_bnsh: + # 4D BNSH: [batch, num_heads, seq_len, head_size] + q_input = q.transpose(1, 2).contiguous() # BSNH → BNSH + new_k_input = new_k.transpose(1, 2).contiguous() + new_v_input = new_v.transpose(1, 2).contiguous() + 
else: + # 3D: [batch, seq_len, hidden_size] + q_input = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + new_k_input = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v_input = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) @@ -463,9 +542,9 @@ def attention_past_func( total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length # Bind inputs - bind_tensor(io_binding, "query", q_3d, device, ort_type) - bind_tensor(io_binding, "key", new_k_3d, device, ort_type) - bind_tensor(io_binding, "value", new_v_3d, device, ort_type) + bind_tensor(io_binding, "query", q_input, device, ort_type) + bind_tensor(io_binding, "key", new_k_input, device, ort_type) + bind_tensor(io_binding, "value", new_v_input, device, ort_type) # Bind optional attention mask if config.has_attn_mask and attn_mask is not None: @@ -489,7 +568,16 @@ def attention_past_func( hidden_size = config.q_num_heads * config.head_size out_dtype = _get_out_dtype(ort_type) - out_torch = torch.zeros((config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device) + if config.use_4d_bnsh: + out_torch = torch.zeros( + (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + dtype=out_dtype, + device=device, + ) + else: + out_torch = torch.zeros( + (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape (past + new) @@ -651,7 +739,7 @@ def create_boolean_mask_from_seqlens( Returns: Boolean mask where True = valid, False = padding. 
- - 2D: [batch_size, total_seq_len] - broadcasts to [batch, 1, 1, total_seq] + - 2D: [q_seq_len, total_seq_len] - broadcasts to [batch, heads, q_seq, total_seq] - 3D: [num_heads, q_seq_len, total_seq_len] - broadcasts to [1, num_heads, q_seq, total_seq] - 4D: [batch_size, num_heads, q_seq_len, total_seq_len] - no broadcasting """ @@ -664,7 +752,10 @@ def create_boolean_mask_from_seqlens( mask_2d = arange < seqlens_expanded # [batch_size, total_seq_len] if mask_dims == 2: - return mask_2d + # 2D: [q_seq_len, total_seq_len] per ONNX spec. Broadcasts across batch and heads. + # Since 2D masks can't vary per batch, use the first batch's pattern. + mask_1d = mask_2d[0:1, :] # [1, total_seq_len] + return mask_1d.expand(q_seq_len, total_seq_len).contiguous() elif mask_dims == 3: # 3D mask: [num_heads, q_seq_len, total_seq_len] # For right-padding tests, all batches should have the same mask pattern per position. diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index d6a9246f7b792..15cc5e3c27803 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -24,6 +24,7 @@ import os import unittest +from unittest.mock import patch import numpy import torch @@ -36,6 +37,7 @@ attention_past_func, attention_prompt_func, attention_ref, + create_additive_mask_from_seqlens, create_boolean_mask_from_seqlens, enable_debug_print, enable_deterministic_check, @@ -152,7 +154,13 @@ def parity_check_gqa_prompt( ) torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + if config.use_4d_bnsh: + # Torch SDPA outputs [B, num_heads, q_seq, head_size] (BNSH format). 
+ # For 4D BNSH test configs, transpose to [B, q_seq, num_heads, head_size] (BSNH) + # to match ORT's 3D output convention for comparison. + out = out.transpose(1, 2).contiguous() + else: + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -315,7 +323,10 @@ def parity_check_gqa_past( present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + if config.use_4d_bnsh: + out = out.transpose(1, 2).contiguous() + else: + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() if enable_debug_print: @@ -388,9 +399,17 @@ def parity_check_gqa_prompt_with_padding( ) v = torch.randn_like(k) * std + # 2D and 3D masks broadcast across batches (no per-batch dimension), so they can only + # represent one padding pattern. The mask uses batch 0's seqlen for all batches. + # Adjust effective_seqlens so the reference comparison matches the actual mask. + if config.attn_mask_dims in (2, 3): + effective_seqlens = torch.full_like(seqlens, seqlens[0].item()) + else: + effective_seqlens = seqlens + # Zero out padded positions in K, V for proper comparison for b in range(config.batch_size): - valid_len = seqlens[b].item() + valid_len = effective_seqlens[b].item() if valid_len < config.kv_sequence_length: k[b, valid_len:, :, :] = 0 v[b, valid_len:, :, :] = 0 @@ -404,12 +423,12 @@ def parity_check_gqa_prompt_with_padding( device=device, ) - key_padding_mask = create_boolean_mask_from_seqlens( - seqlens=seqlens, - total_seq_len=config.kv_sequence_length, - mask_dims=2, - device=device, - ) + # Per-batch key_padding_mask [batch, kv_seq] for reference. 
+ # Must NOT use create_boolean_mask_from_seqlens(..., mask_dims=2) here because that + # returns [q_seq, total_seq] using only the first batch's seqlen, which is wrong + # when effective_seqlens vary per batch (4D mask case). + arange_kv = torch.arange(config.kv_sequence_length, device=device).unsqueeze(0) + key_padding_mask = arange_kv < effective_seqlens.unsqueeze(1) # [batch, kv_seq] # --- PyTorch Reference Path --- out_ref, _ = attention_ref( @@ -437,7 +456,7 @@ def parity_check_gqa_prompt_with_padding( # --- Comparison --- for b in range(config.batch_size): - valid_len = seqlens[b].item() + valid_len = effective_seqlens[b].item() if valid_len < config.q_sequence_length: out[b, valid_len:, :, :] = 0 out_ref[b, valid_len:, :, :] = 0 @@ -658,6 +677,8 @@ def gqa_prompt_padding_test_cases(): Generate test cases for ONNX Attention op GQA path with boolean padding masks. Tests 2D, 3D, and 4D boolean masks for right-padding scenarios. + Includes a batch_size=4, q_seq=1 case where batch_size != q_seq_len to + guard against 2D mask shape bugs (must be [q_seq, total_seq] not [batch, total_seq]). """ batches = [2] seqs = [(16, 16)] @@ -685,6 +706,24 @@ def gqa_prompt_padding_test_cases(): name = f"b{b}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_mask{mask_dims}d" yield name, config + # Guard case: batch_size=4 != q_seq_len=1 (decode). This catches the original bug + # where 2D mask was [batch, total_seq] instead of [q_seq, total_seq]. 
+ for mask_dims in mask_dims_options: + config = AttentionConfig( + batch_size=4, + q_sequence_length=1, + kv_sequence_length=32, + past_kv_sequence_length=0, + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=mask_dims, + ) + name = f"b4_sq1_skv32_nh8_2_h128_mask{mask_dims}d_shape_guard" + yield name, config + def gqa_past_padding_test_cases(): """ @@ -723,12 +762,15 @@ def gqa_past_padding_test_cases(): @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) class TestONNXAttentionFlashGQA(unittest.TestCase): - """Test ONNX Attention op (opset 23) GQA path with Flash Attention.""" + """Test ONNX Attention op (opset 23) GQA path with Flash Attention. + + Requires SM80+: tests explicitly force Flash via ORT_DISABLE_FLASH_ATTENTION=0. + """ @parameterized.expand(gqa_prompt_test_cases()) def test_gqa_prompt_flash(self, name, config): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_prompt( config=config, ep="CUDAExecutionProvider", @@ -742,7 +784,6 @@ def test_gqa_prompt_flash(self, name, config): @parameterized.expand(gqa_past_test_cases()) def test_gqa_past_flash(self, name, config): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_past( config=config, ep="CUDAExecutionProvider", @@ -756,8 +797,13 @@ def test_gqa_past_flash(self, name, config): @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) class TestONNXAttentionFlashGQABF16(unittest.TestCase): - """Test ONNX Attention op (opset 23) GQA path with Flash Attention using BFloat16.""" + """Test ONNX Attention op (opset 23) GQA path with Flash Attention using BFloat16. + + Requires SM80+: tests explicitly force Flash via ORT_DISABLE_FLASH_ATTENTION=0, + and BFloat16 requires Ampere or higher. 
+ """ @parameterized.expand(gqa_prompt_test_cases()) def test_gqa_prompt_flash_bf16(self, name, config): @@ -765,7 +811,6 @@ def test_gqa_prompt_flash_bf16(self, name, config): self.skipTest("BFloat16 not supported on this device") config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_prompt( config=config, ep="CUDAExecutionProvider", @@ -783,7 +828,6 @@ def test_gqa_past_flash_bf16(self, name, config): self.skipTest("BFloat16 not supported on this device") config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_past( config=config, ep="CUDAExecutionProvider", @@ -797,12 +841,12 @@ def test_gqa_past_flash_bf16(self, name, config): @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) class TestONNXAttentionMemoryEfficientGQA(unittest.TestCase): """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention.""" @parameterized.expand(gqa_prompt_test_cases()) def test_gqa_prompt_memory_efficient(self, name, config): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" parity_check_gqa_prompt( config=config, ep="CUDAExecutionProvider", @@ -814,81 +858,32 @@ def test_gqa_prompt_memory_efficient(self, name, config): atol=atol["fp16"], ) - @parameterized.expand(gqa_past_test_cases()) - def test_gqa_past_memory_efficient(self, name, config): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) - - -@unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") -class TestONNXAttentionMemoryEfficientGQABF16(unittest.TestCase): - """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention using 
BFloat16.""" - - @parameterized.expand(gqa_past_test_cases()) - def test_gqa_past_memory_efficient_bf16(self, name, config): - if not torch.cuda.is_bf16_supported(): - self.skipTest("BFloat16 not supported on this device") - - config.kv_cache_type = "bfloat16" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.bfloat16, - ort_type=TensorProto.BFLOAT16, - causal=True, - rtol=rtol["bf16"], - atol=atol["bf16"], - ) + # Note: GQA past tests removed — MEA is ineligible when past_key is present + # (ComputeInternal requires past_key == nullptr for MEA). GQA past requires + # flash attention. @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) class TestONNXAttentionPaddingMaskGQA(unittest.TestCase): """ Test ONNX Attention op (opset 23) GQA path with boolean padding masks. + Requires SM80+: decode+padding tests explicitly force Flash via + ORT_DISABLE_FLASH_ATTENTION=0. Prompt+padding uses MEA fallback and is + tested separately in TestONNXAttentionPaddingMaskMemoryEfficientGQA. + These tests verify that the boolean attn_mask is correctly converted to sequence lengths on GPU and that the attention computation respects the padding. Tests cover 2D, 3D, and 4D mask shapes. 
- """ - - @parameterized.expand(gqa_prompt_padding_test_cases()) - def test_gqa_prompt_padding_flash(self, name, config): - """Test prompt phase with padding mask using Flash Attention.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - seqlens = torch.tensor( - [config.kv_sequence_length - 6, config.kv_sequence_length], - dtype=torch.int32, - device="cuda", - ) - - parity_check_gqa_prompt_with_padding( - config=config, - seqlens=seqlens, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) + Note: prompt+bool_mask is ineligible for flash (routed to MEA), so prompt + padding tests live in TestONNXAttentionPaddingMaskMemoryEfficientGQA only. + """ @parameterized.expand(gqa_past_padding_test_cases()) def test_gqa_past_padding_flash(self, name, config): """Test decoding phase with padding mask using Flash Attention.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - past_seqlens = torch.full( (config.batch_size,), config.past_kv_sequence_length, @@ -909,6 +904,7 @@ def test_gqa_past_padding_flash(self, name, config): @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) class TestONNXAttentionPaddingMaskMemoryEfficientGQA(unittest.TestCase): """ Test ONNX Attention op (opset 23) GQA path with boolean padding masks @@ -918,10 +914,11 @@ class TestONNXAttentionPaddingMaskMemoryEfficientGQA(unittest.TestCase): @parameterized.expand(gqa_prompt_padding_test_cases()) def test_gqa_prompt_padding_mea(self, name, config): """Test prompt phase with padding mask using Memory Efficient Attention.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - + # Create seqlens with config.batch_size elements. + # First batch has shorter valid length, rest at full length. 
+ seqlens_list = [config.kv_sequence_length - 6] + [config.kv_sequence_length] * (config.batch_size - 1) seqlens = torch.tensor( - [config.kv_sequence_length - 6, config.kv_sequence_length], + seqlens_list, dtype=torch.int32, device="cuda", ) @@ -937,29 +934,410 @@ def test_gqa_prompt_padding_mea(self, name, config): atol=atol["fp16"], ) - @parameterized.expand(gqa_past_padding_test_cases()) - def test_gqa_past_padding_mea(self, name, config): - """Test decoding phase with padding mask using Memory Efficient Attention.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - past_seqlens = torch.full( - (config.batch_size,), - config.past_kv_sequence_length, - dtype=torch.int32, +# ################################################################################################# +# Parity Check with nonpad_kv_seqlen (Opset 24) +# ################################################################################################# + + +def parity_check_gqa_prompt_with_nonpad_kv_seqlen( + config: AttentionConfig, + nonpad_seqlens: torch.Tensor, + ep, + device, + torch_type, + ort_type, + rtol, + atol, + std=0.2, +): + """ + Parity check for ONNX Attention op (opset 24) GQA path with nonpad_kv_seqlen. + + nonpad_kv_seqlen tells the op how many KV positions per batch are valid. + Positions beyond the valid length are treated as padding and masked out. + Cannot be used together with past_key/past_value. 
+ """ + torch.manual_seed(0) + + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + v = torch.randn_like(k) * std + + # Zero out padded positions in K, V for proper comparison + for b in range(config.batch_size): + valid_len = nonpad_seqlens[b].item() + if valid_len < config.kv_sequence_length: + k[b, valid_len:, :, :] = 0 + v[b, valid_len:, :, :] = 0 + + # Per-batch key_padding_mask [batch, kv_seq] for reference + arange_kv = torch.arange(config.kv_sequence_length, device=device).unsqueeze(0) + key_padding_mask = arange_kv < nonpad_seqlens.unsqueeze(1).to(device) # [batch, kv_seq] + + out_ref, _ = attention_ref( + q=q, + k=k, + v=v, + key_padding_mask=key_padding_mask, + causal=config.is_causal == 1, + softcap=config.softcap, + ) + + # ORT path: use nonpad_kv_seqlen (int64 tensor) + nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) + + out, present_k, present_v = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=None, + ep=ep, + device=device, + ort_type=ort_type, + nonpad_kv_seqlen=nonpad_kv_seqlen_tensor, + ) + + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + + # When nonpad_kv_seqlen=0 for a batch, all KV positions are masked → softmax yields NaN. + # Zero out those batches in both ORT and reference for comparison. 
+ for b in range(config.batch_size): + if nonpad_seqlens[b].item() == 0: + out[b, :, :, :] = 0 + out_ref[b, :, :, :] = 0 + + out_np = out.to(torch.float32).detach().cpu().numpy() + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +def gqa_nonpad_kv_seqlen_test_cases(): + """ + Generate test cases for ONNX Attention op (opset 24) GQA path with nonpad_kv_seqlen. + + In prompt mode (q_seq == kv_seq), the GQA kernel ignores seqlens_k and uses + padded_seq_lens = sequence_length unconditionally. nonpad_kv_seqlen masking is only + meaningful for decode (q_seq != kv_seq), which routes to FlashAttentionForExternalKVCache. + In the real TensorScatter workflow, prompt mode always has all tokens valid, so + nonpad_kv_seqlen = total_kv_sequence_length (mask nothing). + """ + h = 128 + sq = 16 + skv = 16 + n = 8 + n2 = 2 + + # In prompt mode, nonpad_kv_seqlen should equal total_kv_sequence_length (all tokens valid). + # Partial masking (nonpad < kv_sequence_length) is not supported by the GQA kernel in prompt mode. 
+ seqlen_scenarios = [ + (1, [16], "single_batch"), + (2, [16, 16], "full_len"), + (4, [16, 16, 16, 16], "multi_batch"), + ] + + for batch_size, seqlens, label in seqlen_scenarios: + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + q_num_heads=n, + kv_num_heads=n2, + head_size=h, + is_causal=1, + has_nonpad_kv_seqlen=True, + ) + name = f"b{batch_size}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_{label}" + yield name, config, seqlens + + +def gqa_nonpad_kv_seqlen_cpu_test_cases(): + """CPU-only test cases including zero_seqlen (triggers CUDA_KERNEL_ASSERT in debug builds).""" + yield from gqa_nonpad_kv_seqlen_test_cases() + + h = 128 + sq = 16 + skv = 16 + n = 8 + n2 = 2 + config = AttentionConfig( + batch_size=2, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + q_num_heads=n, + kv_num_heads=n2, + head_size=h, + is_causal=1, + has_nonpad_kv_seqlen=True, + ) + yield f"b2_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_zero_seqlen", config, [0, 16] + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) +class TestONNXAttentionGQANonpadKVSeqlen(unittest.TestCase): + """Test ONNX Attention op (opset 24) GQA path with nonpad_kv_seqlen (Flash Attention). + + Requires SM80+: tests explicitly force Flash via ORT_DISABLE_FLASH_ATTENTION=0. 
+ """ + + @parameterized.expand(gqa_nonpad_kv_seqlen_test_cases()) + def test_gqa_nonpad_kv_seqlen_flash(self, name, config, seqlens): + nonpad_seqlens = torch.tensor(seqlens, dtype=torch.int64, device="cuda") + + parity_check_gqa_prompt_with_nonpad_kv_seqlen( + config=config, + nonpad_seqlens=nonpad_seqlens, + ep="CUDAExecutionProvider", device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + rtol=rtol["fp16"], + atol=atol["fp16"], ) - parity_check_gqa_past_with_padding( + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionGQANonpadKVSeqlenMEA(unittest.TestCase): + """Test ONNX Attention op (opset 24) GQA path with nonpad_kv_seqlen (Memory Efficient Attention).""" + + @parameterized.expand(gqa_nonpad_kv_seqlen_test_cases()) + def test_gqa_nonpad_kv_seqlen_mea(self, name, config, seqlens): + nonpad_seqlens = torch.tensor(seqlens, dtype=torch.int64, device="cuda") + + parity_check_gqa_prompt_with_nonpad_kv_seqlen( + config=config, + nonpad_seqlens=nonpad_seqlens, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +class TestONNXAttentionGQANonpadKVSeqlenCPU(unittest.TestCase): + """Test ONNX Attention op (opset 24) GQA path with nonpad_kv_seqlen on CPU (includes zero_seqlen).""" + + @parameterized.expand(gqa_nonpad_kv_seqlen_cpu_test_cases()) + def test_gqa_nonpad_kv_seqlen_cpu(self, name, config, seqlens): + nonpad_seqlens = torch.tensor(seqlens, dtype=torch.int64, device="cpu") + + parity_check_gqa_prompt_with_nonpad_kv_seqlen( + config=config, + nonpad_seqlens=nonpad_seqlens, + ep="CPUExecutionProvider", + device="cpu", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + +# 
################################################################################################# +# GQA 4D BNSH Format Tests +# ################################################################################################# + + +def gqa_4d_bnsh_test_cases(): + """Generate test cases for GQA with 4D BNSH input format.""" + return [ + ( + "prompt_nomask", + AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + use_4d_bnsh=True, + ), + ), + ( + "prompt_smallhead", + AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + use_4d_bnsh=True, + ), + ), + ] + + +def gqa_4d_bnsh_past_test_cases(): + """Generate test cases for GQA decode with 4D BNSH input format.""" + return [ + ( + "decode_nomask", + AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + use_4d_bnsh=True, + ), + ), + ] + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping 4D BNSH tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) +class TestONNXAttentionGQA4DBNSH(unittest.TestCase): + """ + Test GQA with 4D BNSH input format [batch, num_heads, seq, head_size]. + + Requires SM80+: tests explicitly force Flash via ORT_DISABLE_FLASH_ATTENTION=0. + The C++ attention op detects 4D inputs and sets transpose_output=false. + Flash/MEA always expect BSNH, so the dispatcher transposes Q internally. 
+ """ + + @parameterized.expand(gqa_4d_bnsh_test_cases()) + def test_gqa_4d_bnsh_prompt(self, name, config): + parity_check_gqa_prompt( config=config, - past_seqlens=past_seqlens, ep="CUDAExecutionProvider", device="cuda", torch_type=torch.float16, ort_type=TensorProto.FLOAT16, + causal=True, rtol=rtol["fp16"], atol=atol["fp16"], ) + @parameterized.expand(gqa_4d_bnsh_past_test_cases()) + def test_gqa_4d_bnsh_decode(self, name, config): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +# ################################################################################################# +# GQA Float Additive Mask Tests +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping float mask tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionGQAFloatMask(unittest.TestCase): + """ + Test GQA with float additive attention mask (not bool) during prompt. + + This exercises MEA's GQA expansion + float bias path. The GQA path converts + the additive mask to attention bias for MEA cutlass FMHA. 
+ """ + + def test_gqa_prompt_float_mask_4d(self): + """Test GQA prompt with 4D float additive mask.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + k = torch.randn(2, 16, 2, 128, device=device, dtype=torch_type) * 0.2 + v = torch.randn_like(k) * 0.2 + + # Create additive mask with padding pattern: batch 0 has 10 valid, batch 1 full + seqlens = torch.tensor([10, 16], dtype=torch.int32, device=device) + attn_mask = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=16, + mask_dims=4, + q_seq_len=16, + num_heads=8, + device=device, + dtype=torch_type, + ) + + # Zero padded KV positions + k[0, 10:, :, :] = 0 + v[0, 10:, :, :] = 0 + + # Reference + attn_bias_ref = attn_mask + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) + + # ORT path (MEA handles GQA+float mask) + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_ort = out_ort.reshape(2, 16, 8, 128) + + # Zero padded output for comparison + out_ort[0, 10:, :, :] = 0 + out_ref[0, 10:, :, :] = 0 + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index daa644f40ff41..5cb1e7b7c50b3 100644 --- 
a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -5,20 +5,27 @@ # -------------------------------------------------------------------------- """ -Tests for ONNX Attention op (opset 23) — MHA path (kv_num_heads == q_num_heads). +Tests for ONNX Attention op (opset 23+) — MHA path (kv_num_heads == q_num_heads). The MHA path in attention.cc is exercised when kv_num_heads == q_num_heads. -It uses the unfused attention kernel and supports: +On Ampere+ GPUs with fp16/bf16, the dispatch cascade routes to flash attention +or memory efficient attention first. The unfused kernel handles fp32 and +cases where flash/MEA are ineligible. Tests below marked "unfused" explicitly +disable flash and MEA to verify the unfused kernel. + +Supports: - float32, float16, bfloat16 - - 3D inputs (BSNH format) + - 3D inputs (BSNH format) and 4D inputs (BNSH format) - Causal and non-causal attention - Self-attention and cross-attention (kv_seq != q_seq) - - Additive attention bias (NOT boolean masks) + - Additive attention bias (and boolean masks converted to additive bias) - Past KV cache - - 2D, 3D, 4D additive masks with broadcasting + - 2D, 3D, 4D additive/bool masks with broadcasting """ +import os import unittest +from unittest.mock import patch import numpy import torch @@ -677,13 +684,12 @@ def parity_check_mha_prompt_with_bool_mask( device=device, ) - # Create 2D key_padding_mask for reference (per-batch, shape [batch, total_seq]) - key_padding_mask = create_boolean_mask_from_seqlens( - seqlens=effective_seqlens, - total_seq_len=config.kv_sequence_length, - mask_dims=2, - device=device, - ) + # Per-batch key_padding_mask [batch, kv_seq] for reference. + # Must NOT use create_boolean_mask_from_seqlens(..., mask_dims=2) here because that + # returns [q_seq, total_seq] using only the first batch's seqlen, which is wrong + # when effective_seqlens vary per batch (4D mask case). 
+ arange_kv = torch.arange(config.kv_sequence_length, device=device).unsqueeze(0) + key_padding_mask = arange_kv < effective_seqlens.unsqueeze(1) # [batch, kv_seq] # --- PyTorch Reference Path --- out_ref, _ = attention_ref( @@ -920,5 +926,571 @@ def test_mha_bool_mask_fp16(self, name, config): ) +# ################################################################################################# +# Parity Check with nonpad_kv_seqlen (Opset 24) +# ################################################################################################# + + +def parity_check_mha_prompt_with_nonpad_kv_seqlen( + config: AttentionConfig, + nonpad_seqlens: torch.Tensor, + ep, + device, + torch_type, + ort_type, + rtol, + atol, + std=0.2, +): + """ + Parity check for ONNX Attention op (opset 24) MHA path with nonpad_kv_seqlen. + + nonpad_kv_seqlen tells the op how many KV positions per batch are valid. + Positions beyond the valid length are treated as padding and masked out. + Cannot be used together with past_key/past_value. 
+ """ + torch.manual_seed(0) + + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + v = torch.randn_like(k) * std + + # Zero out padded positions in K, V for proper comparison + for b in range(config.batch_size): + valid_len = nonpad_seqlens[b].item() + if valid_len < config.kv_sequence_length: + k[b, valid_len:, :, :] = 0 + v[b, valid_len:, :, :] = 0 + + # Per-batch key_padding_mask [batch, kv_seq] for reference + arange_kv = torch.arange(config.kv_sequence_length, device=device).unsqueeze(0) + key_padding_mask = arange_kv < nonpad_seqlens.unsqueeze(1).to(device) # [batch, kv_seq] + + out_ref, _ = attention_ref( + q=q, + k=k, + v=v, + key_padding_mask=key_padding_mask, + causal=config.is_causal == 1, + softcap=config.softcap, + ) + + # ORT path: use nonpad_kv_seqlen (int64 tensor) + nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) + + out, present_k, present_v = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=None, + ep=ep, + device=device, + ort_type=ort_type, + nonpad_kv_seqlen=nonpad_kv_seqlen_tensor, + ) + + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + + # When nonpad_kv_seqlen=0 for a batch, all KV positions are masked → softmax yields NaN. + # Zero out those batches in both ORT and reference for comparison. 
+ for b in range(config.batch_size): + if nonpad_seqlens[b].item() == 0: + out[b, :, :, :] = 0 + out_ref[b, :, :, :] = 0 + + out_np = out.to(torch.float32).detach().cpu().numpy() + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +def mha_nonpad_kv_seqlen_test_cases(): + """ + Generate test cases for ONNX Attention op (opset 24) MHA path with nonpad_kv_seqlen. + + MHA supports partial masking in prompt mode via FlashAttentionForExternalKVCache. + Full-length tests verify the no-masking case; partial-length tests verify actual masking. + """ + h = 128 + sq = 16 + skv = 16 + n = 8 + + seqlen_scenarios = [ + (1, [16], "single_batch"), + (2, [16, 16], "full_len"), + (2, [3, 5], "partial_mask"), + (4, [16, 16, 16, 16], "multi_batch"), + ] + + for batch_size, seqlens, label in seqlen_scenarios: + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + q_num_heads=n, + kv_num_heads=n, + head_size=h, + is_causal=0, + has_nonpad_kv_seqlen=True, + attn_mask_type="additive", + ) + name = f"b{batch_size}_sq{sq}_skv{skv}_nh{n}_h{h}_{label}" + yield name, config, seqlens + + # Causal variation with full length + config_c = AttentionConfig( + batch_size=2, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + q_num_heads=n, + kv_num_heads=n, + head_size=h, + is_causal=1, + has_nonpad_kv_seqlen=True, + attn_mask_type="additive", + ) + yield f"b2_sq{sq}_skv{skv}_nh{n}_h{h}_causal", config_c, [16, 16] + + +def mha_nonpad_kv_seqlen_cpu_test_cases(): + """CPU-only test cases including zero_seqlen (triggers CUDA_KERNEL_ASSERT in debug builds).""" + yield from mha_nonpad_kv_seqlen_test_cases() + + h = 128 + sq = 16 + skv = 16 + n = 8 + config = AttentionConfig( + batch_size=2, + q_sequence_length=sq, + kv_sequence_length=skv, + 
past_kv_sequence_length=0, + q_num_heads=n, + kv_num_heads=n, + head_size=h, + is_causal=0, + has_nonpad_kv_seqlen=True, + attn_mask_type="additive", + ) + yield f"b2_sq{sq}_skv{skv}_nh{n}_h{h}_zero_seqlen", config, [0, 5] + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.") +class TestONNXAttentionMHANonpadKVSeqlen(unittest.TestCase): + """Test ONNX Attention op (opset 24) MHA path with nonpad_kv_seqlen on CUDA.""" + + @parameterized.expand(mha_nonpad_kv_seqlen_test_cases()) + def test_mha_nonpad_kv_seqlen_fp16(self, name, config, seqlens): + nonpad_seqlens = torch.tensor(seqlens, dtype=torch.int64, device="cuda") + + parity_check_mha_prompt_with_nonpad_kv_seqlen( + config=config, + nonpad_seqlens=nonpad_seqlens, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @parameterized.expand(mha_nonpad_kv_seqlen_test_cases()) + def test_mha_nonpad_kv_seqlen_fp32(self, name, config, seqlens): + config.kv_cache_type = "float32" + nonpad_seqlens = torch.tensor(seqlens, dtype=torch.int64, device="cuda") + + parity_check_mha_prompt_with_nonpad_kv_seqlen( + config=config, + nonpad_seqlens=nonpad_seqlens, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + +class TestONNXAttentionMHANonpadKVSeqlenCPU(unittest.TestCase): + """Test ONNX Attention op (opset 24) MHA path with nonpad_kv_seqlen on CPU (includes zero_seqlen).""" + + @parameterized.expand(mha_nonpad_kv_seqlen_cpu_test_cases()) + def test_mha_nonpad_kv_seqlen_cpu(self, name, config, seqlens): + config.kv_cache_type = "float32" + nonpad_seqlens = torch.tensor(seqlens, dtype=torch.int64, device="cpu") + + parity_check_mha_prompt_with_nonpad_kv_seqlen( + config=config, + nonpad_seqlens=nonpad_seqlens, + ep="CPUExecutionProvider", + device="cpu", + torch_type=torch.float32, + 
ort_type=TensorProto.FLOAT, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + +# ################################################################################################# +# Unfused Kernel Tests +# ################################################################################################# + + +def mha_unfused_test_cases(): + """ + Generate test cases that force the unfused attention kernel. + + On Ampere+ GPUs, fp16/bf16 normally route to flash attention. By disabling + both flash and MEA, we force the unfused path even for fp16. These tests + verify the unfused kernel handles MHA correctly. + """ + cases = [ + ( + "prompt_causal", + AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=8, + head_size=64, + is_causal=1, + attn_mask_type="additive", + ), + ), + ( + "prompt_noncausal", + AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=8, + head_size=64, + is_causal=0, + attn_mask_type="additive", + ), + ), + ( + "decode_causal", + AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=8, + kv_num_heads=8, + head_size=64, + is_causal=1, + attn_mask_type="additive", + ), + ), + ] + return cases + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping unfused tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION": "1"}) +class TestONNXAttentionMHAUnfused(unittest.TestCase): + """ + Test the unfused attention kernel by disabling flash and MEA. + + On Ampere+ GPUs, fp16 normally routes to flash attention. These tests + disable both flash and MEA to exercise the unfused path. 
+ """ + + @parameterized.expand(mha_unfused_test_cases()) + def test_mha_unfused_fp16(self, name, config): + if "decode" in name: + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=config.is_causal == 1, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + else: + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=config.is_causal == 1, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +# ################################################################################################# +# Broadcast Mask (1,1,q,kv) Tests +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") +class TestONNXAttentionMHABroadcastMask(unittest.TestCase): + """ + Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + + This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that + the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. 
+ """ + + def test_mha_broadcast_mask_additive(self): + """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=8, + head_size=128, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + broadcast_mask_batch=True, + broadcast_mask_heads=True, + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + v = torch.randn_like(k) * 0.2 + + # Create (1,1,q,kv) additive mask: lower-triangular causal pattern + mask_filter = float(torch.finfo(torch_type).min) + mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) + for i in range(16): + mask_2d[i, i + 1 :] = mask_filter + attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) + + # Reference: expand to full (B, H, Q, K) + attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + out_ort = out_ort.reshape(2, 16, 8, 128) + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# 2D Mask Broadcast Regression Test +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") +class 
TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): + """ + Regression test for 2D mask [q_seq, total_seq] broadcast correctness. + + Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts + over batch and heads. This test uses batch_size > q_seq with a non-uniform + mask (different values per row) to verify correct broadcast behavior. + + The old bug indexed the 2D mask by batch index instead of query position, + causing OOB reads when batch_size > q_seq. + """ + + def test_2d_additive_mask_batch_gt_qseq(self): + """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + config = AttentionConfig( + batch_size=4, + q_sequence_length=2, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + mask_filter_value = torch.finfo(torch_type).min + + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * 0.2 + ) + k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * 0.2 + ) + v = torch.randn_like(k) * 0.2 + + # Create a non-uniform 2D causal-style mask [q_seq=2, kv_seq=8]: + # Row 0: attend to positions 0-3, mask out 4-7 + # Row 1: attend to positions 0-5, mask out 6-7 + attn_mask = torch.zeros(config.q_sequence_length, config.kv_sequence_length, device=device, dtype=torch_type) + attn_mask[0, 4:] = mask_filter_value + attn_mask[1, 6:] = mask_filter_value + + # Reference: broadcast [2, 8] → [4, 4, 2, 8] (same mask for all batches/heads) + attn_bias_ref = attn_mask.unsqueeze(0).unsqueeze(0).expand(config.batch_size, config.q_num_heads, -1, -1) + + # PyTorch reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, 
causal=False) + + # ORT path + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + out_ort = out_ort.reshape(config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size) + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + def test_2d_bool_mask_batch_gt_qseq(self): + """2D bool mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + config = AttentionConfig( + batch_size=4, + q_sequence_length=2, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="bool", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * 0.2 + ) + k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * 0.2 + ) + v = torch.randn_like(k) * 0.2 + + # Create non-uniform 2D bool mask [q_seq=2, kv_seq=8]: + # Row 0: True for positions 0-3, False for 4-7 + # Row 1: True for positions 0-5, False for 6-7 + attn_mask = torch.ones(config.q_sequence_length, config.kv_sequence_length, device=device, dtype=torch.bool) + attn_mask[0, 4:] = False + attn_mask[1, 6:] = False + + # No K/V zeroing is needed here: the mask pattern differs per query row, so there + # is no single per-batch padding length whose K/V entries could be zeroed out. + # Instead, the reference applies the bool mask directly as an additive attention + # bias (attn_bias_ref below) rather than as a key_padding_mask. + 
mask_filter_value = torch.finfo(torch_type).min + attn_bias_ref = torch.where( + attn_mask.unsqueeze(0).unsqueeze(0).expand(config.batch_size, config.q_num_heads, -1, -1), + torch.tensor(0.0, dtype=torch_type, device=device), + torch.tensor(mask_filter_value, dtype=torch_type, device=device), + ) + + # PyTorch reference with explicit bias + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + out_ort = out_ort.reshape(config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size) + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py new file mode 100644 index 0000000000000..a6a115bb12213 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py @@ -0,0 +1,979 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +Tests for TensorScatter(opset 24) + Attention(opset 24) pattern. + +Demonstrates a decode step where new KV entries are scattered into a +pre-allocated cache via TensorScatter, then Attention uses the updated +KV cache with nonpad_kv_seqlen to mask out padding positions. 
+ +Uses IO Binding for in-place KV cache updates, matching the real-world LLM +inference pattern where KV cache buffers are pre-allocated on the device +and reused across decode steps. + +The graph looks like: + + key_cache (B, S, kv_hidden) ──────────┐ + new_k (B, q_seq, kv_hidden) ──────────┤ + write_indices (B,) ───────────────────┤ + ├─ TensorScatter(axis=1) ─→ updated_key_cache ─┐ + │ + value_cache (B, S, kv_hidden) ────────┐ │ + new_v (B, q_seq, kv_hidden) ──────────┤ │ + write_indices (B,) ──────────────────┤ │ + ├─ TensorScatter(axis=1) ─→ updated_value_cache ┤ + │ + Q (B, q_seq, q_hidden) ──────────────┬─ Attention(opset 24) ←──────────────────────────┘ + nonpad_kv_seqlen (B,) ──────────────┘ │ + ├─ output + ├─ present_key + └─ present_value + +IO Binding enables in-place cache updates: the same OrtValue buffer is bound as +both TensorScatter input (key_cache/value_cache) and output +(updated_key_cache/updated_value_cache), avoiding unnecessary copies. + +CUDA support: + - GQA path (kv_num_heads != q_num_heads) uses flash attention for external KV cache (fp16/bf16) + - MHA path (kv_num_heads == q_num_heads) uses flash attention for fp16/bf16, + unfused attention_bias fallback for fp32 +""" + +import math +import unittest + +import numpy +import torch +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers + +# ################################################################################################# +# Helper Functions +# ################################################################################################# + + +def has_cuda_provider(): + return "CUDAExecutionProvider" in get_available_providers() + + +def has_cuda_device(min_capability: int = 53): + if not has_cuda_provider() or not torch.cuda.is_available(): + return False + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor >= min_capability + 
+ +def has_flash_attention(): + """Return True if the CUDA device meets the SM80+ requirement for Flash Attention.""" + return has_cuda_device(80) + + +def numpy_attention_ref(q, k, v, nonpad_kv_seqlen, is_causal=False, attn_bias=None): + """ + NumPy reference implementation of scaled dot-product attention with padding mask. + + Args: + q: Query [batch, q_seq, num_heads, head_size] + k: Key [batch, kv_seq, kv_num_heads, head_size] + v: Value [batch, kv_seq, kv_num_heads, head_size] + nonpad_kv_seqlen: [batch] — number of valid KV positions per batch + is_causal: whether to apply causal masking + attn_bias: optional additive attention bias, broadcastable to [batch, num_heads, q_seq, kv_seq] + + Returns: + output: [batch, q_seq, num_heads, head_size] + """ + batch_size, q_seq, num_heads, head_size = q.shape + _, kv_seq, kv_num_heads, _ = k.shape + groups = num_heads // kv_num_heads + + # Repeat KV heads for GQA + if groups > 1: + k = numpy.repeat(k, groups, axis=2) + v = numpy.repeat(v, groups, axis=2) + + scale = 1.0 / math.sqrt(head_size) + + # scores: [batch, num_heads, q_seq, kv_seq] + q_t = numpy.transpose(q, (0, 2, 1, 3)) + k_t = numpy.transpose(k, (0, 2, 3, 1)) + scores = numpy.matmul(q_t, k_t) * scale + + # Apply nonpad_kv_seqlen mask: positions >= valid_len get -inf + for b in range(batch_size): + valid_len = int(nonpad_kv_seqlen[b]) + if valid_len < kv_seq: + scores[b, :, :, valid_len:] = -numpy.inf + + # Apply additive attention bias (from attn_mask conversion) + if attn_bias is not None: + scores = scores + attn_bias + + # Apply causal mask + if is_causal: + for sq in range(q_seq): + offset = kv_seq - q_seq + for sk in range(kv_seq): + if sk > sq + offset: + scores[:, :, sq, sk] = -numpy.inf + + # Softmax along last axis + # Handle all-masked rows: if entire row is -inf, softmax gives nan; we want 0. + # This happens when nonpad_kv_seqlen=0 for a batch (all KV positions masked). 
+ # Callers zero out those batches in both ORT and reference outputs for comparison. + max_scores = numpy.max(scores, axis=-1, keepdims=True) + # Clip -inf max to 0 to avoid nan in exp + max_scores = numpy.where(numpy.isinf(max_scores) & (max_scores < 0), 0.0, max_scores) + exp_scores = numpy.exp(scores - max_scores) + sum_exp = numpy.sum(exp_scores, axis=-1, keepdims=True) + sum_exp = numpy.where(sum_exp == 0.0, 1.0, sum_exp) + attention = exp_scores / sum_exp + + # output: [batch, num_heads, q_seq, head_size] + v_t = numpy.transpose(v, (0, 2, 1, 3)) + output = numpy.matmul(attention, v_t) + + # Transpose back: [batch, q_seq, num_heads, head_size] + output = numpy.transpose(output, (0, 2, 1, 3)) + return output + + +def build_tensorscatter_attention_graph( + batch_size, + total_kv_seq_len, + q_seq_len, + q_num_heads, + kv_num_heads, + head_size, + ort_type, + is_causal=0, +): + """ + Build ONNX graph: TensorScatter(opset 24) → Attention(opset 24). + + TensorScatter uses write_indices [B] to scatter new KV entries into cache + at per-batch positions. Attention uses updated cache with nonpad_kv_seqlen + to mask padding. + + The graph exposes updated_key_cache and updated_value_cache as graph outputs + to enable in-place buffer binding via IO Binding. 
+ + Inputs: + 0: key_cache [B, total_kv_seq_len, kv_hidden] + 1: value_cache [B, total_kv_seq_len, kv_hidden] + 2: new_k [B, q_seq_len, kv_hidden] + 3: new_v [B, q_seq_len, kv_hidden] + 4: write_indices [B] (int64 — per-batch write position) + 5: query [B, q_seq_len, q_hidden] + 6: nonpad_kv_seqlen [B] (int64 — valid KV length after scatter) + + Outputs: + 0: output [B, q_seq_len, q_hidden] + 1: present_key [B, kv_num_heads, total_kv_seq_len, head_size] + 2: present_value [B, kv_num_heads, total_kv_seq_len, head_size] + 3: updated_key_cache [B, total_kv_seq_len, kv_hidden] + 4: updated_value_cache [B, total_kv_seq_len, kv_hidden] + """ + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + + # TensorScatter for key cache update (axis=1: sequence dim in [B, S, H]) + scatter_k_node = helper.make_node( + "TensorScatter", + inputs=["key_cache", "new_k", "write_indices"], + outputs=["updated_key_cache"], + name="TensorScatterKey", + axis=1, + ) + + # TensorScatter for value cache update + scatter_v_node = helper.make_node( + "TensorScatter", + inputs=["value_cache", "new_v", "write_indices"], + outputs=["updated_value_cache"], + name="TensorScatterValue", + axis=1, + ) + + # Attention node with nonpad_kv_seqlen + attention_node = helper.make_node( + "Attention", + inputs=[ + "query", + "updated_key_cache", + "updated_value_cache", + "", # attn_mask + "", # past_key + "", # past_value + "nonpad_kv_seqlen", + ], + outputs=["output", "present_key", "present_value"], + name="Attention_0", + is_causal=is_causal, + kv_num_heads=kv_num_heads, + q_num_heads=q_num_heads, + softcap=0.0, + qk_matmul_output_mode=0, + domain="", + ) + + # Graph inputs + cache_shape = [batch_size, total_kv_seq_len, kv_hidden] + graph_inputs = [ + helper.make_tensor_value_info("key_cache", ort_type, cache_shape), + helper.make_tensor_value_info("value_cache", ort_type, cache_shape), + helper.make_tensor_value_info("new_k", ort_type, [batch_size, q_seq_len, kv_hidden]), + 
helper.make_tensor_value_info("new_v", ort_type, [batch_size, q_seq_len, kv_hidden]), + helper.make_tensor_value_info("write_indices", TensorProto.INT64, [batch_size]), + helper.make_tensor_value_info("query", ort_type, [batch_size, q_seq_len, q_hidden]), + helper.make_tensor_value_info("nonpad_kv_seqlen", TensorProto.INT64, [batch_size]), + ] + + # Graph outputs: Attention outputs + TensorScatter outputs for in-place binding + present_shape = [batch_size, kv_num_heads, total_kv_seq_len, head_size] + graph_outputs = [ + helper.make_tensor_value_info("output", ort_type, [batch_size, q_seq_len, q_hidden]), + helper.make_tensor_value_info("present_key", ort_type, present_shape), + helper.make_tensor_value_info("present_value", ort_type, present_shape), + helper.make_tensor_value_info("updated_key_cache", ort_type, cache_shape), + helper.make_tensor_value_info("updated_value_cache", ort_type, cache_shape), + ] + + graph = helper.make_graph( + [scatter_k_node, scatter_v_node, attention_node], + "TensorScatterAttention_Graph", + graph_inputs, + graph_outputs, + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 24)]) + return model.SerializeToString() + + +def run_tensorscatter_attention( + batch_size, + total_kv_seq_len, + q_seq_len, + q_num_heads, + kv_num_heads, + head_size, + nonpad_seqlens, + scatter_positions, + ep, + torch_type, + ort_type, + is_causal=0, + std=0.2, +): + """ + Run TensorScatter + Attention test with IO Binding and compare against NumPy reference. + + Uses IO Binding to: + 1. Pre-allocate KV cache as OrtValues on the target device + 2. Bind the same OrtValue as both TensorScatter input and output (in-place update) + 3. Feed the updated cache to Attention + 4. Pre-allocate output buffers on the target device + + Args: + scatter_positions: list of ints per batch — the write index for TensorScatter. + nonpad_seqlens: list of ints per batch — valid KV length AFTER scatter. 
+ is_causal: 1 for causal attention, 0 for non-causal. + """ + torch.manual_seed(42) + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + np_type = numpy.float16 if torch_type == torch.float16 else numpy.float32 + + # Generate test data as numpy arrays via torch for reproducible seeding + key_cache_np = (torch.randn(batch_size, total_kv_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + value_cache_np = (torch.randn(batch_size, total_kv_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + + # Zero out padding positions in cache + for b in range(batch_size): + old_valid = max(0, nonpad_seqlens[b] - q_seq_len) + if old_valid < total_kv_seq_len: + key_cache_np[b, old_valid:, :] = 0 + value_cache_np[b, old_valid:, :] = 0 + + new_k_np = (torch.randn(batch_size, q_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + new_v_np = (torch.randn(batch_size, q_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + query_np = (torch.randn(batch_size, q_seq_len, q_hidden, dtype=torch_type) * std).numpy() + write_indices_np = numpy.array(scatter_positions, dtype=numpy.int64) + nonpad_kv_seqlen_np = numpy.array(nonpad_seqlens, dtype=numpy.int64) + + # --- NumPy reference --- + # Compute reference in float32 for accuracy + key_cache_ref = key_cache_np.astype(numpy.float32).copy() + value_cache_ref = value_cache_np.astype(numpy.float32).copy() + new_k_ref = new_k_np.astype(numpy.float32) + new_v_ref = new_v_np.astype(numpy.float32) + + for b in range(batch_size): + pos = scatter_positions[b] + for t in range(q_seq_len): + key_cache_ref[b, pos + t, :] = new_k_ref[b, t, :] + value_cache_ref[b, pos + t, :] = new_v_ref[b, t, :] + + # Reshape to BSNH for reference attention + q_ref = query_np.astype(numpy.float32).reshape(batch_size, q_seq_len, q_num_heads, head_size) + k_ref = key_cache_ref.reshape(batch_size, total_kv_seq_len, kv_num_heads, head_size) + v_ref = value_cache_ref.reshape(batch_size, total_kv_seq_len, kv_num_heads, head_size) + + ref_output 
= numpy_attention_ref(q_ref, k_ref, v_ref, nonpad_seqlens, is_causal=bool(is_causal)) + ref_output_3d = ref_output.reshape(batch_size, q_seq_len, q_hidden) + + # Compute expected present_key/present_value: BSNH → BNSH transpose of updated cache. + # Attention op with no past_key simply reshapes+transposes K/V to [B, H, S, D]. + ref_present_k = k_ref.transpose(0, 2, 1, 3) # [B, kv_num_heads, total_kv_seq_len, head_size] + ref_present_v = v_ref.transpose(0, 2, 1, 3) + + # --- ORT execution with IO Binding --- + onnx_model_str = build_tensorscatter_attention_graph( + batch_size=batch_size, + total_kv_seq_len=total_kv_seq_len, + q_seq_len=q_seq_len, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + ort_type=ort_type, + is_causal=is_causal, + ) + + sess_options = SessionOptions() + session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + + # Determine device for OrtValue allocation + ort_device = "cuda" if "CUDA" in ep else "cpu" + device_id = 0 + + # Create OrtValues for inputs on target device + key_cache_ort = OrtValue.ortvalue_from_numpy(key_cache_np, ort_device, device_id) + value_cache_ort = OrtValue.ortvalue_from_numpy(value_cache_np, ort_device, device_id) + new_k_ort = OrtValue.ortvalue_from_numpy(new_k_np, ort_device, device_id) + new_v_ort = OrtValue.ortvalue_from_numpy(new_v_np, ort_device, device_id) + write_indices_ort = OrtValue.ortvalue_from_numpy(write_indices_np, ort_device, device_id) + query_ort = OrtValue.ortvalue_from_numpy(query_np, ort_device, device_id) + nonpad_ort = OrtValue.ortvalue_from_numpy(nonpad_kv_seqlen_np, ort_device, device_id) + + # Pre-allocate output buffers on target device + present_shape = [batch_size, kv_num_heads, total_kv_seq_len, head_size] + output_ort = OrtValue.ortvalue_from_shape_and_type( + [batch_size, q_seq_len, q_hidden], np_type, ort_device, device_id + ) + present_k_ort = OrtValue.ortvalue_from_shape_and_type(present_shape, np_type, ort_device, device_id) + 
present_v_ort = OrtValue.ortvalue_from_shape_and_type(present_shape, np_type, ort_device, device_id) + + # Set up IO binding + io_binding = session.io_binding() + + # Bind all inputs + io_binding.bind_ortvalue_input("key_cache", key_cache_ort) + io_binding.bind_ortvalue_input("value_cache", value_cache_ort) + io_binding.bind_ortvalue_input("new_k", new_k_ort) + io_binding.bind_ortvalue_input("new_v", new_v_ort) + io_binding.bind_ortvalue_input("write_indices", write_indices_ort) + io_binding.bind_ortvalue_input("query", query_ort) + io_binding.bind_ortvalue_input("nonpad_kv_seqlen", nonpad_ort) + + # Bind Attention outputs to pre-allocated buffers + io_binding.bind_ortvalue_output("output", output_ort) + io_binding.bind_ortvalue_output("present_key", present_k_ort) + io_binding.bind_ortvalue_output("present_value", present_v_ort) + + # Bind TensorScatter outputs to the SAME OrtValues as inputs (in-place update). + # TensorScatter declares MayInplace(0, 0), so ORT will skip the copy when + # input and output share the same buffer. 
+ io_binding.bind_ortvalue_output("updated_key_cache", key_cache_ort) + io_binding.bind_ortvalue_output("updated_value_cache", value_cache_ort) + + # Execute with IO binding + io_binding.synchronize_inputs() + session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # Read results from pre-bound OrtValues + output_result = output_ort.numpy() + present_k_result = present_k_ort.numpy() + present_v_result = present_v_ort.numpy() + + return output_result, ref_output_3d, present_k_result, present_v_result, ref_present_k, ref_present_v + + +# ################################################################################################# +# Test Case Generator +# ################################################################################################# + +# Shared test dimensions +_HEAD_SIZE = 64 +_TOTAL_KV_SEQ_LEN = 8 + +_GQA_CASES = [ + # (batch, q_seq, q_heads, kv_heads, scatter_positions, nonpad_seqlens, label) + (1, 1, 8, 2, [3], [4], "gqa_batch1"), + (2, 1, 8, 2, [2, 4], [3, 5], "gqa_diff_lens"), + (2, 1, 8, 2, [4, 4], [5, 5], "gqa_same_lens"), + (2, 1, 8, 2, [0, 3], [1, 4], "gqa_one_empty"), + (2, 1, 8, 2, [7, 7], [8, 8], "gqa_full_len"), + # Additional GQA ratios + (2, 1, 16, 4, [2, 5], [3, 6], "gqa_16h_4kvh"), + (2, 1, 6, 3, [3, 3], [4, 4], "gqa_6h_3kvh"), +] + +_MHA_CASES = [ + (1, 1, 4, 4, [3], [4], "mha_batch1"), + (2, 1, 4, 4, [2, 4], [3, 5], "mha_diff_lens"), + (2, 1, 4, 4, [4, 4], [5, 5], "mha_same_lens"), + (2, 1, 4, 4, [0, 3], [1, 4], "mha_one_empty"), + (2, 1, 4, 4, [7, 7], [8, 8], "mha_full_len"), +] + + +def _make_test_params(cases, is_causal): + """Convert raw case tuples into parameterized test parameter tuples.""" + causal_str = "causal" if is_causal else "noncausal" + for batch, q_seq, q_heads, kv_heads, scatter_pos, seqlens, label in cases: + name = f"b{batch}_qs{q_seq}_qh{q_heads}_kvh{kv_heads}_h{_HEAD_SIZE}_{label}_{causal_str}" + yield ( + name, + batch, + q_seq, + q_heads, + kv_heads, + _HEAD_SIZE, + 
_TOTAL_KV_SEQ_LEN, + scatter_pos, + seqlens, + is_causal, + ) + + +def cpu_test_cases(): + """CPU: all modes, non-causal and causal (both GQA and MHA work without restrictions).""" + yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=0) + yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=1) + + +def cuda_fp16_test_cases(): + """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly.""" + yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=0) + yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=1) + + +def cuda_fp32_test_cases(): + """CUDA fp32: MHA only. GQA requires fp16/bf16, and flash attention requires fp16/bf16. + fp32 MHA uses the unfused attention_bias fallback path.""" + yield from _make_test_params(_MHA_CASES, is_causal=0) + yield from _make_test_params(_MHA_CASES, is_causal=1) + + +# ################################################################################################# +# Test Classes +# ################################################################################################# + +# Default tolerances (CUDA fp16/fp32 need looser tolerances due to TF32 and reduced precision) +rtol = {"fp16": 5e-3, "fp32": 5e-3} +atol = {"fp16": 5e-3, "fp32": 5e-3} +# CPU fp32 has no TF32 — use tighter tolerance +cpu_fp32_rtol = 1e-5 +cpu_fp32_atol = 1e-5 + + +class TestTensorScatterAttentionCPU(unittest.TestCase): + """Test TensorScatter + Attention (opset 24) on CPU with float32 and IO Binding.""" + + @parameterized.expand(cpu_test_cases()) + def test_tensorscatter_attention_cpu_fp32( + self, + name, + batch, + q_seq, + q_heads, + kv_heads, + head_size, + total_kv, + scatter_pos, + seqlens, + is_causal, + ): + output, ref_output, present_k, present_v, ref_present_k, ref_present_v = run_tensorscatter_attention( + batch_size=batch, + total_kv_seq_len=total_kv, + q_seq_len=q_seq, + q_num_heads=q_heads, + kv_num_heads=kv_heads, + head_size=head_size, + nonpad_seqlens=seqlens, + 
scatter_positions=scatter_pos, + ep="CPUExecutionProvider", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + is_causal=is_causal, + ) + numpy.testing.assert_allclose(output, ref_output, rtol=cpu_fp32_rtol, atol=cpu_fp32_atol) + numpy.testing.assert_allclose(present_k, ref_present_k, rtol=cpu_fp32_rtol, atol=cpu_fp32_atol) + numpy.testing.assert_allclose(present_v, ref_present_v, rtol=cpu_fp32_rtol, atol=cpu_fp32_atol) + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping tests.") +class TestTensorScatterAttentionCUDAFP16(unittest.TestCase): + """Test TensorScatter + Attention (opset 24) on CUDA with float16 and IO Binding. + + On SM80+ Flash Attention is used; on SM75+ MEA handles the fallback; + on older GPUs the unfused path runs. The cascade in attention.cc picks + the best available backend automatically. + """ + + @parameterized.expand(cuda_fp16_test_cases()) + def test_tensorscatter_attention_cuda_fp16( + self, + name, + batch, + q_seq, + q_heads, + kv_heads, + head_size, + total_kv, + scatter_pos, + seqlens, + is_causal, + ): + output, ref_output, present_k, present_v, ref_present_k, ref_present_v = run_tensorscatter_attention( + batch_size=batch, + total_kv_seq_len=total_kv, + q_seq_len=q_seq, + q_num_heads=q_heads, + kv_num_heads=kv_heads, + head_size=head_size, + nonpad_seqlens=seqlens, + scatter_positions=scatter_pos, + ep="CUDAExecutionProvider", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + is_causal=is_causal, + ) + numpy.testing.assert_allclose(output, ref_output, rtol=rtol["fp16"], atol=atol["fp16"]) + numpy.testing.assert_allclose(present_k, ref_present_k, rtol=rtol["fp16"], atol=atol["fp16"]) + numpy.testing.assert_allclose(present_v, ref_present_v, rtol=rtol["fp16"], atol=atol["fp16"]) + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping tests.") +class TestTensorScatterAttentionCUDAFP32(unittest.TestCase): + """Test TensorScatter + Attention (opset 24) on 
CUDA with float32 and IO Binding. + + Only MHA cases: CUDA GQA path requires float16. + """ + + @parameterized.expand(cuda_fp32_test_cases()) + def test_tensorscatter_attention_cuda_fp32( + self, + name, + batch, + q_seq, + q_heads, + kv_heads, + head_size, + total_kv, + scatter_pos, + seqlens, + is_causal, + ): + output, ref_output, present_k, present_v, ref_present_k, ref_present_v = run_tensorscatter_attention( + batch_size=batch, + total_kv_seq_len=total_kv, + q_seq_len=q_seq, + q_num_heads=q_heads, + kv_num_heads=kv_heads, + head_size=head_size, + nonpad_seqlens=seqlens, + scatter_positions=scatter_pos, + ep="CUDAExecutionProvider", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + is_causal=is_causal, + ) + numpy.testing.assert_allclose(output, ref_output, rtol=rtol["fp32"], atol=atol["fp32"]) + numpy.testing.assert_allclose(present_k, ref_present_k, rtol=rtol["fp32"], atol=atol["fp32"]) + numpy.testing.assert_allclose(present_v, ref_present_v, rtol=rtol["fp32"], atol=atol["fp32"]) + + +# ################################################################################################# +# TensorScatter + Attention with nonpad_kv_seqlen + attn_mask (T26 / T31) +# ################################################################################################# + + +def build_tensorscatter_attention_graph_with_mask( + batch_size, + total_kv_seq_len, + q_seq_len, + q_num_heads, + kv_num_heads, + head_size, + ort_type, + mask_type, + mask_shape, + is_causal=0, +): + """ + Build ONNX graph: TensorScatter(opset 24) → Attention(opset 24) with both + nonpad_kv_seqlen AND attn_mask inputs. + + Args: + mask_type: TensorProto type for the mask (BOOL or same as ort_type for additive). + mask_shape: shape of the attn_mask tensor (e.g., [q_seq, total_kv_seq] for 2D). 
+ """ + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + + scatter_k_node = helper.make_node( + "TensorScatter", + inputs=["key_cache", "new_k", "write_indices"], + outputs=["updated_key_cache"], + name="TensorScatterKey", + axis=1, + ) + scatter_v_node = helper.make_node( + "TensorScatter", + inputs=["value_cache", "new_v", "write_indices"], + outputs=["updated_value_cache"], + name="TensorScatterValue", + axis=1, + ) + + attention_node = helper.make_node( + "Attention", + inputs=[ + "query", + "updated_key_cache", + "updated_value_cache", + "attn_mask", + "", # past_key + "", # past_value + "nonpad_kv_seqlen", + ], + outputs=["output", "present_key", "present_value"], + name="Attention_0", + is_causal=is_causal, + kv_num_heads=kv_num_heads, + q_num_heads=q_num_heads, + softcap=0.0, + qk_matmul_output_mode=0, + domain="", + ) + + cache_shape = [batch_size, total_kv_seq_len, kv_hidden] + graph_inputs = [ + helper.make_tensor_value_info("key_cache", ort_type, cache_shape), + helper.make_tensor_value_info("value_cache", ort_type, cache_shape), + helper.make_tensor_value_info("new_k", ort_type, [batch_size, q_seq_len, kv_hidden]), + helper.make_tensor_value_info("new_v", ort_type, [batch_size, q_seq_len, kv_hidden]), + helper.make_tensor_value_info("write_indices", TensorProto.INT64, [batch_size]), + helper.make_tensor_value_info("query", ort_type, [batch_size, q_seq_len, q_hidden]), + helper.make_tensor_value_info("nonpad_kv_seqlen", TensorProto.INT64, [batch_size]), + helper.make_tensor_value_info("attn_mask", mask_type, mask_shape), + ] + + present_shape = [batch_size, kv_num_heads, total_kv_seq_len, head_size] + graph_outputs = [ + helper.make_tensor_value_info("output", ort_type, [batch_size, q_seq_len, q_hidden]), + helper.make_tensor_value_info("present_key", ort_type, present_shape), + helper.make_tensor_value_info("present_value", ort_type, present_shape), + helper.make_tensor_value_info("updated_key_cache", ort_type, cache_shape), 
+ helper.make_tensor_value_info("updated_value_cache", ort_type, cache_shape), + ] + + graph = helper.make_graph( + [scatter_k_node, scatter_v_node, attention_node], + "TensorScatterAttentionWithMask_Graph", + graph_inputs, + graph_outputs, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 24)]) + return model.SerializeToString() + + +def run_tensorscatter_attention_with_mask( + batch_size, + total_kv_seq_len, + q_seq_len, + q_num_heads, + kv_num_heads, + head_size, + nonpad_seqlens, + scatter_positions, + mask_positions_to_block, + use_bool_mask, + ep, + torch_type, + ort_type, + is_causal=0, + std=0.2, +): + """ + Run TensorScatter + Attention test with BOTH nonpad_kv_seqlen AND attn_mask. + + Args: + mask_positions_to_block: list of KV position indices to mask out via attn_mask + (applied uniformly across all batches since 2D mask broadcasts). + use_bool_mask: True for bool mask, False for float additive mask. + """ + torch.manual_seed(42) + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + np_type = numpy.float16 if torch_type == torch.float16 else numpy.float32 + + # Generate test data + key_cache_np = (torch.randn(batch_size, total_kv_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + value_cache_np = (torch.randn(batch_size, total_kv_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + + for b in range(batch_size): + old_valid = max(0, nonpad_seqlens[b] - q_seq_len) + if old_valid < total_kv_seq_len: + key_cache_np[b, old_valid:, :] = 0 + value_cache_np[b, old_valid:, :] = 0 + + new_k_np = (torch.randn(batch_size, q_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + new_v_np = (torch.randn(batch_size, q_seq_len, kv_hidden, dtype=torch_type) * std).numpy() + query_np = (torch.randn(batch_size, q_seq_len, q_hidden, dtype=torch_type) * std).numpy() + write_indices_np = numpy.array(scatter_positions, dtype=numpy.int64) + nonpad_kv_seqlen_np = numpy.array(nonpad_seqlens, dtype=numpy.int64) + + # 
Create attn_mask: 2D [q_seq, total_kv_seq] + if use_bool_mask: + mask_np = numpy.ones((q_seq_len, total_kv_seq_len), dtype=numpy.bool_) + for pos in mask_positions_to_block: + mask_np[:, pos] = False + mask_ort_type = TensorProto.BOOL + # Reference: convert bool to additive bias for numpy_attention_ref + ref_bias = numpy.zeros((1, 1, q_seq_len, total_kv_seq_len), dtype=numpy.float32) + for pos in mask_positions_to_block: + ref_bias[:, :, :, pos] = -numpy.inf + else: + mask_np = numpy.zeros((q_seq_len, total_kv_seq_len), dtype=np_type) + for pos in mask_positions_to_block: + mask_np[:, pos] = numpy.finfo(np_type).min + mask_ort_type = ort_type + ref_bias = numpy.zeros((1, 1, q_seq_len, total_kv_seq_len), dtype=numpy.float32) + for pos in mask_positions_to_block: + ref_bias[:, :, :, pos] = float(numpy.finfo(np_type).min) + + # --- NumPy reference --- + key_cache_ref = key_cache_np.astype(numpy.float32).copy() + value_cache_ref = value_cache_np.astype(numpy.float32).copy() + new_k_ref = new_k_np.astype(numpy.float32) + new_v_ref = new_v_np.astype(numpy.float32) + + for b in range(batch_size): + pos = scatter_positions[b] + for t in range(q_seq_len): + key_cache_ref[b, pos + t, :] = new_k_ref[b, t, :] + value_cache_ref[b, pos + t, :] = new_v_ref[b, t, :] + + q_ref = query_np.astype(numpy.float32).reshape(batch_size, q_seq_len, q_num_heads, head_size) + k_ref = key_cache_ref.reshape(batch_size, total_kv_seq_len, kv_num_heads, head_size) + v_ref = value_cache_ref.reshape(batch_size, total_kv_seq_len, kv_num_heads, head_size) + + ref_output = numpy_attention_ref(q_ref, k_ref, v_ref, nonpad_seqlens, is_causal=bool(is_causal), attn_bias=ref_bias) + ref_output_3d = ref_output.reshape(batch_size, q_seq_len, q_hidden) + + # Compute expected present_key/present_value: BSNH → BNSH transpose of updated cache. 
+ ref_present_k = k_ref.transpose(0, 2, 1, 3) + ref_present_v = v_ref.transpose(0, 2, 1, 3) + + # --- ORT execution --- + mask_shape = [q_seq_len, total_kv_seq_len] + onnx_model_str = build_tensorscatter_attention_graph_with_mask( + batch_size=batch_size, + total_kv_seq_len=total_kv_seq_len, + q_seq_len=q_seq_len, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + ort_type=ort_type, + mask_type=mask_ort_type, + mask_shape=mask_shape, + is_causal=is_causal, + ) + + sess_options = SessionOptions() + session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + + ort_device = "cuda" if "CUDA" in ep else "cpu" + device_id = 0 + + key_cache_ort = OrtValue.ortvalue_from_numpy(key_cache_np, ort_device, device_id) + value_cache_ort = OrtValue.ortvalue_from_numpy(value_cache_np, ort_device, device_id) + new_k_ort = OrtValue.ortvalue_from_numpy(new_k_np, ort_device, device_id) + new_v_ort = OrtValue.ortvalue_from_numpy(new_v_np, ort_device, device_id) + write_indices_ort = OrtValue.ortvalue_from_numpy(write_indices_np, ort_device, device_id) + query_ort = OrtValue.ortvalue_from_numpy(query_np, ort_device, device_id) + nonpad_ort = OrtValue.ortvalue_from_numpy(nonpad_kv_seqlen_np, ort_device, device_id) + mask_ort = OrtValue.ortvalue_from_numpy(mask_np, ort_device, device_id) + + present_shape = [batch_size, kv_num_heads, total_kv_seq_len, head_size] + output_ort = OrtValue.ortvalue_from_shape_and_type( + [batch_size, q_seq_len, q_hidden], np_type, ort_device, device_id + ) + present_k_ort = OrtValue.ortvalue_from_shape_and_type(present_shape, np_type, ort_device, device_id) + present_v_ort = OrtValue.ortvalue_from_shape_and_type(present_shape, np_type, ort_device, device_id) + + io_binding = session.io_binding() + io_binding.bind_ortvalue_input("key_cache", key_cache_ort) + io_binding.bind_ortvalue_input("value_cache", value_cache_ort) + io_binding.bind_ortvalue_input("new_k", new_k_ort) + io_binding.bind_ortvalue_input("new_v", 
new_v_ort) + io_binding.bind_ortvalue_input("write_indices", write_indices_ort) + io_binding.bind_ortvalue_input("query", query_ort) + io_binding.bind_ortvalue_input("nonpad_kv_seqlen", nonpad_ort) + io_binding.bind_ortvalue_input("attn_mask", mask_ort) + + io_binding.bind_ortvalue_output("output", output_ort) + io_binding.bind_ortvalue_output("present_key", present_k_ort) + io_binding.bind_ortvalue_output("present_value", present_v_ort) + io_binding.bind_ortvalue_output("updated_key_cache", key_cache_ort) + io_binding.bind_ortvalue_output("updated_value_cache", value_cache_ort) + + io_binding.synchronize_inputs() + session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + output_result = output_ort.numpy() + present_k_result = present_k_ort.numpy() + present_v_result = present_v_ort.numpy() + return output_result, ref_output_3d, present_k_result, present_v_result, ref_present_k, ref_present_v + + +# Test cases for nonpad_kv_seqlen + attn_mask combination +# Format: (batch, q_seq, q_heads, kv_heads, scatter_pos, nonpad_seqlens, mask_positions, label) +_NONPAD_MASK_CASES = [ + # Single batch: mask position 1 within valid range + (1, 1, 4, 4, [3], [4], [1], "mha_b1_mask1pos"), + # Multi-batch with different valid lengths, mask position 0 + (2, 1, 4, 4, [2, 4], [3, 5], [0], "mha_b2_mask_pos0"), + # GQA with mask blocking two positions + (2, 1, 8, 2, [2, 4], [3, 5], [1, 2], "gqa_b2_mask2pos"), + # Larger batch with varied lengths + (4, 1, 4, 4, [1, 3, 5, 7], [2, 4, 6, 8], [0, 3], "mha_b4_varied"), + # GQA with full valid length, mask some positions + (2, 1, 8, 2, [7, 7], [8, 8], [2, 5], "gqa_b2_full_mask2"), +] + + +def _make_mask_test_params(cases, use_bool_mask): + """Generate parameterized test cases for nonpad + mask tests.""" + mask_str = "bool" if use_bool_mask else "float" + for batch, q_seq, q_heads, kv_heads, scatter_pos, seqlens, mask_pos, label in cases: + name = f"b{batch}_qh{q_heads}_kvh{kv_heads}_{label}_{mask_str}" + yield ( + name, + 
batch, + q_seq, + q_heads, + kv_heads, + _HEAD_SIZE, + _TOTAL_KV_SEQ_LEN, + scatter_pos, + seqlens, + mask_pos, + use_bool_mask, + ) + + +def nonpad_mask_cpu_test_cases(): + """CPU test cases for nonpad_kv_seqlen + attn_mask, both bool and float masks.""" + yield from _make_mask_test_params(_NONPAD_MASK_CASES, use_bool_mask=True) + yield from _make_mask_test_params(_NONPAD_MASK_CASES, use_bool_mask=False) + + +class TestTensorScatterAttentionWithMaskCPU(unittest.TestCase): + """Test TensorScatter + Attention with both nonpad_kv_seqlen and attn_mask on CPU. + + Exercises the T26 fix: graceful fallback from Flash to MEA when both inputs present. + On CPU, both masks compose additively in the reference attention implementation. + """ + + @parameterized.expand(nonpad_mask_cpu_test_cases()) + def test_nonpad_with_mask_cpu( + self, + name, + batch, + q_seq, + q_heads, + kv_heads, + head_size, + total_kv, + scatter_pos, + seqlens, + mask_pos, + use_bool_mask, + ): + output, ref_output, present_k, present_v, ref_present_k, ref_present_v = run_tensorscatter_attention_with_mask( + batch_size=batch, + total_kv_seq_len=total_kv, + q_seq_len=q_seq, + q_num_heads=q_heads, + kv_num_heads=kv_heads, + head_size=head_size, + nonpad_seqlens=seqlens, + scatter_positions=scatter_pos, + mask_positions_to_block=mask_pos, + use_bool_mask=use_bool_mask, + ep="CPUExecutionProvider", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + ) + numpy.testing.assert_allclose(output, ref_output, rtol=cpu_fp32_rtol, atol=cpu_fp32_atol) + numpy.testing.assert_allclose(present_k, ref_present_k, rtol=cpu_fp32_rtol, atol=cpu_fp32_atol) + numpy.testing.assert_allclose(present_v, ref_present_v, rtol=cpu_fp32_rtol, atol=cpu_fp32_atol) + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping tests.") +class TestTensorScatterAttentionWithMaskCUDA(unittest.TestCase): + """Test TensorScatter + Attention with both nonpad_kv_seqlen and attn_mask on CUDA. 
+ + Exercises the MEA path which supports seqlen_k_ptr + attn_bias simultaneously. + Flash is excluded when both inputs are present; MEA handles the combination. + """ + + @parameterized.expand(_make_mask_test_params(_NONPAD_MASK_CASES, use_bool_mask=True)) + def test_nonpad_with_bool_mask_cuda_fp16( + self, + name, + batch, + q_seq, + q_heads, + kv_heads, + head_size, + total_kv, + scatter_pos, + seqlens, + mask_pos, + use_bool_mask, + ): + output, ref_output, present_k, present_v, ref_present_k, ref_present_v = run_tensorscatter_attention_with_mask( + batch_size=batch, + total_kv_seq_len=total_kv, + q_seq_len=q_seq, + q_num_heads=q_heads, + kv_num_heads=kv_heads, + head_size=head_size, + nonpad_seqlens=seqlens, + scatter_positions=scatter_pos, + mask_positions_to_block=mask_pos, + use_bool_mask=use_bool_mask, + ep="CUDAExecutionProvider", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + ) + numpy.testing.assert_allclose(output, ref_output, rtol=rtol["fp16"], atol=atol["fp16"]) + numpy.testing.assert_allclose(present_k, ref_present_k, rtol=rtol["fp16"], atol=atol["fp16"]) + numpy.testing.assert_allclose(present_v, ref_present_v, rtol=rtol["fp16"], atol=atol["fp16"]) + + +if __name__ == "__main__": + unittest.main()