diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index abdcc81586909..75b43c363f371 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -655,7 +655,8 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
+|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|24+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
+|||23|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 6c08d7fbd9b3f..8cccc2f1a725c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -132,6 +132,16 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block);
+// BxNxSxH => BxSxNxH
+Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const float* input, float* output, cudaStream_t stream, const int max_threads_per_block);
+
+Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const half* input, half* output, cudaStream_t stream, const int max_threads_per_block);
+
+Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block);
+
template <typename T>
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length,
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu
index e7177987fa2d1..e5f1ffca251b7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu
@@ -393,6 +393,26 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c
max_threads_per_block, false, input, output);
}
+// BxNxSxH => BxSxNxH (BNSH to BSNH) — reverse of Transpose_BSNH_to_BNSH.
+// Reuses the existing TransposeCtx kernel which does exactly this transformation.
+Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) {
+ return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
+ max_threads_per_block, false, input, output);
+}
+
+Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const half* input, half* output, cudaStream_t stream, const int max_threads_per_block) {
+ return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
+ max_threads_per_block, false, input, output);
+}
+
+Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+ const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block) {
+ return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
+ max_threads_per_block, false, input, output);
+}
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index 1b68d70617744..29bb4fba6a09a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -258,6 +258,33 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;
+ // Bias stride alignment check: route to the unaligned kernel when bias strides
+ // don't satisfy the aligned kernel's kAlignmentQ requirement.
+ //
+ // kAlignmentQ is template-dependent (kernel_forward.h:414):
+ // isAligned=true: kAlignmentQ = DefaultConfig::kAlignmentA (8 for fp16/bf16 SM75+)
+ // isAligned=false: kAlignmentQ = GemmType::kMinimumAlignment (4 for fp16/bf16 SM75+)
+ // So check_supported (line 632) enforces DIFFERENT thresholds per path.
+ //
+ // The ONNX Attention kernel (core/providers/cuda/llm/attention.cc) gates MEA eligibility
+ // at kMinimumAlignment (4), allowing strides like 12 that the unaligned kernel handles.
+ // Without this check, such inputs dispatch to the aligned kernel where 12%8≠0 crashes.
+ // Contrib MHA gates at 4*sizeof(T)=8 for fp16, making this check redundant there.
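+  // Worked example (assumed sizes): num_keys = 12 gives bias_strideM = 12; 12 % 4 == 0 so the
+  // op-level gate admits the input, but 12 % 8 != 0 fails kAlignmentQ for the aligned kernel,
+  // so is_aligned is cleared here and the unaligned kernel is used instead of crashing.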
+ if (params.attn_bias != nullptr) {
+ int num_keys = params.kv_sequence_length;
+ int num_queries = params.sequence_length;
+ int bias_strideM = num_keys;
+ // Broadcast dimensions use stride=0, which satisfies any alignment (0 % N == 0).
+ int bias_strideH = params.broadcast_attn_bias_dim_1 ? 0 : num_queries * num_keys;
+ int bias_strideB = params.broadcast_attn_bias_dim_0
+ ? 0
+ : ((params.broadcast_attn_bias_dim_1 ? 1 : params.num_heads) * num_queries * num_keys);
+ is_aligned = is_aligned &&
+ bias_strideM % AlignedAK::kAlignmentQ == 0 &&
+ (params.num_heads <= 1 || bias_strideH % AlignedAK::kAlignmentQ == 0) &&
+ (params.batch_size <= 1 || bias_strideB % AlignedAK::kAlignmentQ == 0);
+ }
+
DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha(params);
}));
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 961c80748d228..d2f74386b701e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -1147,6 +1147,24 @@ template Status QkvToContext<__nv_bfloat16, __nv_fp8_e4m3>(
GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>& data);
#endif
+// Explicit instantiations for cross-TU usage by core/providers/cuda/llm/attention.cc
+template Status LaunchUngroup<__half>(
+ const GroupQueryAttentionParameters& parameters,
+ float2* k_buff, float2* v_buff,
+ const float2* k_og, const float2* v_og,
+ const int buff_seqlen, const int og_seqlen,
+ const bool is_bsnh,
+ cudaStream_t stream,
+ const int max_threads_per_block);
+template Status LaunchUngroup<__nv_bfloat16>(
+ const GroupQueryAttentionParameters& parameters,
+ float2* k_buff, float2* v_buff,
+ const float2* k_og, const float2* v_og,
+ const int buff_seqlen, const int og_seqlen,
+ const bool is_bsnh,
+ cudaStream_t stream,
+ const int max_threads_per_block);
+
template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block);
template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index 78b061837e402..125ab8f76132c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -122,6 +122,16 @@ struct GQABufferRequirements {
}
};
+// Also used by ONNX Attention (core/providers/cuda/llm/attention.cc) for GQA head expansion in MEA path.
+template <typename T>
+Status LaunchUngroup(const GroupQueryAttentionParameters& parameters,
+ float2* k_buff, float2* v_buff,
+ const float2* k_og, const float2* v_og,
+ const int buff_seqlen, const int og_seqlen,
+ const bool is_bsnh,
+ cudaStream_t stream,
+ const int max_threads_per_block);
+
Status LaunchGetSequenceLengths(
const int* total_seq_lens_minus_one,
int* past_seq_lens,
diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h
index e7df1a078472a..efc908603a990 100644
--- a/onnxruntime/core/providers/cpu/llm/attention_helper.h
+++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h
@@ -27,7 +27,8 @@ inline Status ComputeOutputShapeForAttention(
TensorShape& y_shape,
TensorShape& present_key_shape,
TensorShape& present_value_shape,
- TensorShape& output_qk_shape) {
+ TensorShape& output_qk_shape,
+ bool skip_nonpad_data_validation = false) {
ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
"Q, K, and V inputs must not be null");
int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions());
@@ -90,6 +91,13 @@ inline Status ComputeOutputShapeForAttention(
parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3)
parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]);
+
+ // Validate mask second-to-last dim matches q_sequence_length (same check as 4D path).
+ // For 2D mask [A, B]: A must equal q_seq. For 3D mask [A, B, C]: B must equal q_seq.
+ ORT_ENFORCE(attn_mask == nullptr ||
+ attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[1],
+ "inconsistent q_sequence_length (between attn_mask and Q)");
+
parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads;
parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]);
parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads;
@@ -115,11 +123,14 @@ inline Status ComputeOutputShapeForAttention(
parameters.has_nonpad_kv_seqlen = true;
parameters.nonpad_kv_seqlen_data = nonpad_kv_seqlen->Data();
// Validate each value is in [0, total_sequence_length].
- for (int i = 0; i < parameters.batch_size; ++i) {
- ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
- parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
- "nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
- " is out of range [0, ", parameters.total_sequence_length, "]");
+ // Skip when data is on GPU (CUDA path sets skip_nonpad_data_validation=true).
+ if (!skip_nonpad_data_validation) {
+ for (int i = 0; i < parameters.batch_size; ++i) {
+ ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
+ parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
+ "nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
+ " is out of range [0, ", parameters.total_sequence_length, "]");
+ }
}
} else {
parameters.has_nonpad_kv_seqlen = false;
@@ -127,6 +138,11 @@ inline Status ComputeOutputShapeForAttention(
}
ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads. This is required for grouped/multi-query and multi-headed attention.");
+ // TODO: The ONNX spec allows attn_mask last dim to be shorter than total_sequence_length,
+ // with positions beyond the mask padded with -inf. Currently we enforce exact match.
+ // To support: change == to <=, allocate padded buffer, fill remainder with -inf.
+ // See ONNX spec: 'The last dimension can also be shorter than total_sequence_length
+ // and will be padded to total_sequence_length with negative infinity.'
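+ // Illustrative example (hypothetical shapes): with total_sequence_length = 128 and a mask whose
+ // last dimension is 120, the spec-compliant behavior would treat the missing 8 columns as -inf;
+ // today such a mask is rejected by the exact-match check below.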
ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
"inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
ORT_ENFORCE(attn_mask == nullptr ||
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index bf6fcc7ccf0a8..92e1e5108cb8a 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1592,9 +1592,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
// Opset 23.
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization);
@@ -1633,6 +1633,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention);
#endif
@@ -2671,9 +2674,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
// Opset 23
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2711,6 +2714,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
// Opset 24
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention)>,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
index 8fc7a49827087..4f7fd6a664bdc 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/attention.h"
#include "core/providers/cpu/llm/attention_helper.h"
@@ -20,10 +19,11 @@ namespace onnxruntime {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Attention, \
kOnnxDomain, \
23, \
+ 23, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
@@ -36,11 +36,26 @@ REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
+#define REGISTER_KERNEL_TYPED_24(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ Attention, \
+ kOnnxDomain, \
+ 24, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("U", BuildKernelDefConstraints()), \
+ Attention);
+
+REGISTER_KERNEL_TYPED_24(float)
+REGISTER_KERNEL_TYPED_24(MLFloat16)
+REGISTER_KERNEL_TYPED_24(BFloat16)
+
template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info) {
is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1;
- // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs.
- // The dimension is not yet known. If not specified, the inputs is assumed to be 4D.
kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0));
q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0));
int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0));
@@ -52,589 +67,1100 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) {
qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKMask ||
qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftCap ||
qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kQKSoftMax,
- "qk_matmul_output_mode must be 0, 1, 2, or 3.");
- // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed.
+ "qk_matmul_output_mode must be one of: kNone(-1), kQK(0), kQKMask(1), kQKSoftCap(2), kQKSoftMax(3).");
scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN());
softcap_ = info.GetAttrOrDefault("softcap", 0.0f);
softmax_precision_ = static_cast(info.GetAttrOrDefault("softmax_precision", 0));
ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified");
+
+ const auto* kernel_options = this->GetAttentionKernelOptions();
+ disable_flash_attention_ = std::is_same<T, float>::value || !kernel_options->UseFlashAttention();
+ disable_memory_efficient_attention_ = !kernel_options->UseEfficientAttention();
}
+// ============================================================================
+// Transpose helpers: eliminate repeated if-constexpr type-switch blocks.
+// T is the ORT type (MLFloat16, BFloat16, float). The helpers map T to the
+// corresponding CUDA type via ToCudaType::MappedType and forward to the
+// overloaded Transpose functions in contrib_ops.
+// ============================================================================
+
template <typename T>
-Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
- const Tensor* Q = context->Input(0);
- const Tensor* K = context->Input(1);
- const Tensor* V = context->Input(2);
- const Tensor* attn_mask = context->Input(3);
- const Tensor* past_key = context->Input(4);
- const Tensor* past_value = context->Input(5);
- const Tensor* nonpad_kv_seqlen = context->Input(6); // optional, Opset 24
+static Status TransposeBNSHtoBSNH(int batch_size, int sequence_length,
+ int num_heads, int head_size,
+ const void* input, void* output,
+ cudaStream_t stream, int max_threads_per_block) {
+ using CudaT = typename ToCudaType<T>::MappedType;
+ return onnxruntime::contrib::cuda::Transpose_BNSH_to_BSNH(
+ batch_size, sequence_length, num_heads, head_size,
+ reinterpret_cast<const CudaT*>(input),
+ reinterpret_cast<CudaT*>(output),
+ stream, max_threads_per_block);
+}
- attention_helper::AttentionParameters parameters;
- TensorShape y_shape;
- TensorShape present_key_shape;
- TensorShape present_value_shape;
- TensorShape output_qk_shape;
+template <typename T>
+static Status TransposeBSNHtoBNSH(int batch_size, int sequence_length,
+ int num_heads, int head_size,
+ const void* input, void* output,
+ cudaStream_t stream, int max_threads_per_block) {
+ using CudaT = typename ToCudaType<T>::MappedType;
+ return onnxruntime::contrib::cuda::Transpose_BSNH_to_BNSH(
+ batch_size, sequence_length, num_heads, head_size,
+ reinterpret_cast<const CudaT*>(input),
+ reinterpret_cast<CudaT*>(output),
+ stream, max_threads_per_block);
+}
- ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention(
- Q,
- K,
- V,
- attn_mask,
- past_key,
- past_value,
- nonpad_kv_seqlen,
- is_causal_,
- softcap_,
- softmax_precision_,
- qk_matmul_output_mode_,
- kv_num_heads_,
- q_num_heads_,
- scale_,
- parameters,
- y_shape,
- present_key_shape,
- present_value_shape,
- output_qk_shape)
- .IsOK(),
- "Output shapes for Attention could not be computed.");
+// ============================================================================
+// ConvertAttnMaskToBias: shared helper for mask→additive bias conversion.
+// Used by both Flash (nonpad+mask) and MEA paths to avoid code duplication.
+// Converts bool masks to additive bias (true→0, false→mask_filter_value),
+// passes float masks through directly, and sets broadcast flags from mask shape.
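+// Illustrative example (assumed shapes): a 2D bool mask of shape (q_seq, total_seq) such as
+// [[true, true, false]] becomes the additive bias [[0, 0, mask_filter_value]] and, being 2D,
+// sets broadcast_bias_dim_0 = broadcast_bias_dim_1 = true so it broadcasts over batch and heads.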
+// ============================================================================
+template <typename T>
+Status Attention<T>::ConvertAttnMaskToBias(
+ OpKernelContext* context,
+ const Tensor* attn_mask,
+ cudaStream_t cuda_stream,
+ int max_threads_per_block,
+ IAllocatorUniquePtr& converted_mask_buffer,
+ const void*& attn_bias_data,
+ bool& broadcast_bias_dim_0,
+ bool& broadcast_bias_dim_1) const {
+ if (attn_mask->IsDataType<bool>()) {
+ using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
+ int64_t num_elements = attn_mask->Shape().Size();
+ converted_mask_buffer = GetScratchBuffer(
+ num_elements * sizeof(NativeCudaT), context->GetComputeStream());
+ float mask_filter_value = static_cast(std::numeric_limits::lowest());
+ ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias(
+ attn_mask->Data<bool>(),
+ reinterpret_cast<NativeCudaT*>(converted_mask_buffer.get()),
+ num_elements, mask_filter_value, cuda_stream,
+ max_threads_per_block));
+ attn_bias_data = converted_mask_buffer.get();
+ } else {
+ attn_bias_data = attn_mask->Data();
+ }
- Tensor* Y = context->Output(0, y_shape);
- Tensor* present_key = context->Output(1, present_key_shape);
- Tensor* present_value = context->Output(2, present_value_shape);
- Tensor* output_qk = context->Output(3, output_qk_shape);
+ size_t mask_dims = attn_mask->Shape().NumDimensions();
+ auto dims = attn_mask->Shape().GetDims();
+ if (mask_dims == 2) {
+ broadcast_bias_dim_0 = true;
+ broadcast_bias_dim_1 = true;
+ } else if (mask_dims == 3) {
+ broadcast_bias_dim_0 = true;
+ broadcast_bias_dim_1 = dims[0] == 1;
+ } else {
+ broadcast_bias_dim_0 = dims[0] == 1;
+ broadcast_bias_dim_1 = dims[1] == 1;
+ }
+ return Status::OK();
+}
- // To reuse the existing attention-cuda implementation in contrib ops,
- // map the parameters to contribop_parameters (MHA).
- onnxruntime::contrib::AttentionParameters contribop_parameters;
+// ============================================================================
+// RunFlashAttention: Direct flash attention kernel call
+// ============================================================================
+//
+// Flash Attention dispatch paths:
+// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache
+// Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache
+// - Supports: bool mask -> seqlens_k, no mask -> fill past_seq_len
+// - 4D BNSH: transposes Q/K/V to BSNH before kernel
+// Path 3: no past, no mask (prompt) -> mha_fwd
+// Eligibility: fp16/bf16, head_size==v_head_size, no output_qk,
+// (no mask OR bool mask + past OR nonpad_kv_seqlen without mask)
+// Note: softcap and softmax_precision are early-rejected before the cascade.
+// Note: nonpad_kv_seqlen + attn_mask is supported but routes to MEA/unfused,
+// not Flash (Flash has no bias parameter for this combination).
+//
+// PERFORMANCE NOTE: ONNX Attention's internal-cache decode path (past_key/past_value)
+// is ~15-30% slower than contrib GQA's decode path for grouped-query attention workloads.
+// When using external KV cache via TensorScatter + nonpad_kv_seqlen (opset 24), the
+// copy overhead (point 1) is eliminated. The remaining ~5-15% gap is from the missing
+// XQA kernel (point 2).
+//
+// The internal-cache overhead comes from:
+//
+// 1. No past_present_share_buffer: The ONNX Attention spec requires past_key/value
+// shape = (B, H, past_seq, head_size) and present_key/value shape =
+// (B, H, total_seq, head_size) where total_seq = past_seq + kv_seq.
+// Since past and present have different shapes, they cannot share the same buffer.
+// Contrib GQA allows past and present to be the same tensor (in-place append),
+// eliminating the memset + strided copy overhead (~67MB per decode step for typical LLM).
+// This overhead does NOT apply to the external-cache path (TensorScatter +
+// nonpad_kv_seqlen), which bypasses past/present entirely.
+//
+// 2. No XQA kernel: GQA's specialized XQA decode kernel (xqa_loader.h) requires
+// past_present_share_buffer to function. Since ONNX Attention cannot share buffers
+// (see point 1), XQA is fundamentally incompatible with this op's spec design.
+// This accounts for the remaining ~5-15% gap even on the external-cache path.
+//
+// 3. These are spec-level limitations, not implementation gaps. For production LLM
+// inference, the external-cache path (TensorScatter + nonpad_kv_seqlen) is
+// recommended and achieves near-parity with contrib GQA performance.
+//
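+// Illustrative routing examples (assumed workloads): a first prompt with no mask and no past
+// takes Path 3 (mha_fwd); a single-token decode step with past_key/past_value and a bool
+// right-padding mask takes Path 2 (mha_fwd_kvcache); an opset-24 graph that scatters into an
+// external KV cache and passes nonpad_kv_seqlen takes Path 1.
+//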
+template <typename T>
+Status Attention<T>::RunFlashAttention(
+ OpKernelContext* context,
+ const Tensor* Q, const Tensor* K, const Tensor* V,
+ const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
+ const Tensor* nonpad_kv_seqlen,
+ Tensor* Y, Tensor* present_key, Tensor* present_value,
+ const attention_helper::AttentionParameters& parameters) const {
+#if USE_FLASH_ATTENTION
+ auto& device_prop = GetDeviceProp();
+ auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
+ const bool is_bf16 = std::is_same<T, BFloat16>::value;
+ const bool is_bsnh = parameters.transpose_output; // 3D inputs → BSNH
- // QKV format: Determine based on input dimensions
- // 3D inputs (B, S, D): Q_K_V_BSNH - will be transposed by PrepareQkv to BNSH
- // transpose_output is true for 3D inputs, false for 4D inputs
- if (!parameters.transpose_output) {
- contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
- contribop_parameters.is_output_bnsh = true;
- } else {
- // 3D inputs in BSNH format (will be transposed)
- contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH;
- contribop_parameters.is_output_bnsh = false;
+ // --- Common buffer allocation ---
+ size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(
+ parameters.q_sequence_length, parameters.batch_size, parameters.q_num_heads);
+
+ auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] =
+ onnxruntime::flash::get_num_splits_and_buffer_sizes(
+ parameters.batch_size, parameters.q_sequence_length,
+ parameters.total_sequence_length, parameters.q_num_heads,
+ parameters.head_size, device_prop.multiProcessorCount);
+
+ auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream());
+ auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream());
+ auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream());
+
+ if (softmax_lse_accum_bytes > 0) {
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(softmax_lse_accum_buffer.get(), 0,
+ softmax_lse_accum_bytes, cuda_stream));
+ }
+ if (out_accum_bytes > 0) {
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(out_accum_buffer.get(), 0,
+ out_accum_bytes, cuda_stream));
}
- // Check if this is Group Query Attention (GQA)
- const bool is_gqa = parameters.kv_num_heads != parameters.q_num_heads;
+ // --- Transpose Q from BNSH to BSNH (flash always expects Q as BSNH) ---
+ const void* q_data = Q->Data();
+ IAllocatorUniquePtr q_bsnh_buffer;
+ if (!is_bsnh) {
+ size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length *
+ parameters.q_num_heads * parameters.head_size;
+ q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream());
+ ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+ parameters.batch_size, parameters.q_sequence_length,
+ parameters.q_num_heads, parameters.head_size,
+ Q->Data(), q_bsnh_buffer.get(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ q_data = q_bsnh_buffer.get();
+ }
- if (is_gqa) {
- // Use GQA path with Flash Attention or Memory Efficient Attention
- // GQA only supports float16 and bfloat16 types
- if constexpr (std::is_same::value) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "GQA in Attention op (CUDA) does not support float32. "
- "Please use float16 or bfloat16.");
+ // Flash outputs BSNH. If Y expects BNSH, write to scratch then transpose.
+ void* out_data = Y->MutableData();
+ IAllocatorUniquePtr out_bsnh_buffer;
+ if (!is_bsnh) {
+ size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length *
+ parameters.q_num_heads * parameters.v_head_size;
+ out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream());
+ out_data = out_bsnh_buffer.get();
+ }
+
+ bool present_kv_already_populated = false;
+
+ // --- Path 1: nonpad_kv_seqlen (opset 24 external KV cache) ---
+ if (nonpad_kv_seqlen != nullptr) {
+ ORT_ENFORCE(parameters.past_sequence_length == 0,
+ "RunFlashAttention with nonpad_kv_seqlen requires K/V to be the full cache "
+ "(past_sequence_length must be 0, got ",
+ parameters.past_sequence_length, ").");
+
+ auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream());
+ ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK(
+ nonpad_kv_seqlen->Data(),
+ seqlens_k_buffer.get(),
+ parameters.batch_size,
+ parameters.total_sequence_length,
+ cuda_stream,
+ device_prop.maxThreadsPerBlock));
+
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+ device_prop, cuda_stream,
+ const_cast(q_data),
+ const_cast(static_cast(K->Data())),
+ const_cast(static_cast(V->Data())),
+ /*k=*/nullptr, /*v=*/nullptr,
+ out_data,
+ softmax_lse_buffer.get(),
+ const_cast(static_cast(seqlens_k_buffer.get())),
+ /*rotary_cos=*/nullptr, /*rotary_sin=*/nullptr,
+ /*cache_batch_idx=*/nullptr, /*leftpad_k=*/nullptr,
+ /*head_sink=*/nullptr, /*block_table=*/nullptr,
+ parameters.batch_size, parameters.q_num_heads, parameters.kv_num_heads,
+ parameters.head_size,
+ parameters.q_sequence_length, parameters.kv_sequence_length,
+ /*seqlen_k_new=*/0, /*rotary_dim=*/0,
+ parameters.scale, parameters.softcap,
+ parameters.is_causal, is_bf16, /*use_smooth_softmax=*/false,
+ /*past_bsnh=*/is_bsnh,
+ static_cast(num_splits),
+ softmax_lse_accum_buffer.get(), out_accum_buffer.get(),
+ /*local_window_size=*/-1, /*is_rotary_interleaved=*/false,
+ /*is_packed_qkv=*/false));
+ }
+ // --- Path 2: Decode with past KV cache ---
+ else if (past_key != nullptr) {
+ ORT_ENFORCE(past_value != nullptr, "past_key requires past_value.");
+ ORT_ENFORCE(present_key != nullptr && present_value != nullptr,
+ "present_key/value outputs are required when past_key is provided.");
+
+ // Zero present buffers before strided copy to avoid stale data in positions
+ // beyond past_seq that mha_fwd_kvcache might read during attention (matching GQA pattern).
+ // NOTE: This memset + strided copy is the main decode overhead vs contrib GQA.
+ // See PERFORMANCE NOTE above for why buffer sharing is impossible under the ONNX spec.
+ const size_t num_kv_rows = static_cast(parameters.batch_size) * parameters.kv_num_heads;
+ const size_t present_k_bytes = num_kv_rows * parameters.total_sequence_length *
+ parameters.head_size * sizeof(T);
+ const size_t present_v_bytes = num_kv_rows * parameters.total_sequence_length *
+ parameters.v_head_size * sizeof(T);
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_key->MutableData(), 0,
+ present_k_bytes, cuda_stream));
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_value->MutableData(), 0,
+ present_v_bytes, cuda_stream));
+
+ // Copy past KV (BNSH) into present buffers (BNSH). Strided copy because
+ // past has [B, N_kv, past_seq, H] and present has [B, N_kv, total_seq, H].
+ const size_t past_k_row_bytes = static_cast(parameters.past_sequence_length) *
+ parameters.head_size * sizeof(T);
+ const size_t present_k_row_bytes = static_cast(parameters.total_sequence_length) *
+ parameters.head_size * sizeof(T);
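+ // Worked example (assumed sizes): batch_size = 1, kv_num_heads = 8, head_size = 128, fp16,
+ // past_seq = 127, total_seq = 128: each of the 8 rows copies 127*128*2 = 32512 bytes of past
+ // into a present row pitched at 128*128*2 = 32768 bytes, leaving the final token slot for the
+ // new K/V appended by mha_fwd_kvcache.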
+ CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync(
+ present_key->MutableData(), present_k_row_bytes,
+ past_key->Data(), past_k_row_bytes,
+ past_k_row_bytes, num_kv_rows,
+ cudaMemcpyDeviceToDevice, cuda_stream));
+
+ const size_t past_v_row_bytes = static_cast(parameters.past_sequence_length) *
+ parameters.v_head_size * sizeof(T);
+ const size_t present_v_row_bytes = static_cast(parameters.total_sequence_length) *
+ parameters.v_head_size * sizeof(T);
+ CUDA_RETURN_IF_ERROR(cudaMemcpy2DAsync(
+ present_value->MutableData(), present_v_row_bytes,
+ past_value->Data(), past_v_row_bytes,
+ past_v_row_bytes, num_kv_rows,
+ cudaMemcpyDeviceToDevice, cuda_stream));
+
+ // seqlens_k: derive per-batch sequence lengths for the KV cache.
+ // mha_fwd_kvcache expects seqlens_k = tokens already in cache BEFORE appending new ones.
+ // When a bool mask is present, it encodes total valid token count (past + new).
+ // Subtract kv_sequence_length to get the pre-append count.
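+ // Worked example (assumed values): past_sequence_length = 127, kv_sequence_length = 1, and the
+ // bool mask marks 128 valid positions for a batch row; the mask-derived count is 128, so the
+ // offset of -kv_sequence_length yields seqlens_k = 127, the pre-append cache length.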
+ auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream());
+ if (attn_mask != nullptr && attn_mask->IsDataType()) {
+ size_t mask_dims = attn_mask->Shape().NumDimensions();
+ auto dims = attn_mask->Shape().GetDims();
+ int64_t mask_dim0 = dims[0];
+ int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0;
+ int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0;
+ // Offset: mask gives total valid count, subtract kv_sequence_length for pre-append count.
+ int seqlen_offset = -parameters.kv_sequence_length;
+ ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK(
+ attn_mask->Data(), seqlens_k_buffer.get(),
+ parameters.batch_size, parameters.total_sequence_length,
+ static_cast(mask_dims), mask_dim0, mask_dim1, mask_dim2,
+ cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset));
} else {
- // GQA only supports 3D inputs (B, S, D) in BSNH format, not 4D inputs (B, num_heads, S, head_size) in BNSH format
- if (!parameters.transpose_output) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
- "4D QKV inputs (BNSH format) are not supported yet in GQA path of Attention op (CUDA). "
- "Please use 3D inputs (B, S, hidden_size) instead.");
- }
- // For now, GQA doesn't support qk_matmul_output_mode other than kNone
- if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
- "qk_matmul_output_mode is not supported yet in GQA path of Attention op (CUDA).");
- }
- // GQA doesn't support softmax_precision yet
- if (parameters.softmax_precision != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
- "softmax_precision is not supported yet in GQA path of Attention op (CUDA).");
- }
- // causal attention is required for GQA
- if (!parameters.is_causal) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
- "Non-causal attention is not supported yet in GQA path of Attention op (CUDA).");
- }
- // GQA kernel expects K/V input sequence length == Q sequence length (self-attention only)
- // Cross-attention (kv_sequence_length != q_sequence_length) is not supported
- if (parameters.kv_sequence_length != parameters.q_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
- "Cross-attention (kv_sequence_length != q_sequence_length) is not supported in "
- "GQA path of Attention op (CUDA). kv_sequence_length=",
- parameters.kv_sequence_length, ", q_sequence_length=", parameters.q_sequence_length);
- }
+ ORT_RETURN_IF_ERROR(LaunchFillInt32(seqlens_k_buffer.get(), parameters.past_sequence_length,
+ parameters.batch_size, cuda_stream,
+ device_prop.maxThreadsPerBlock));
+ }
- auto& device_prop = GetDeviceProp();
-
- // Bridge parameters to GroupQueryAttentionParameters
- onnxruntime::contrib::GroupQueryAttentionParameters gqa_parameters;
- gqa_parameters.batch_size = parameters.batch_size;
- gqa_parameters.sequence_length = parameters.q_sequence_length;
- gqa_parameters.seqlen_past_kv_cache = parameters.past_sequence_length;
- gqa_parameters.seqlen_present_kv_cache = parameters.total_sequence_length;
- gqa_parameters.total_sequence_length = parameters.total_sequence_length;
- gqa_parameters.kv_sequence_length = parameters.kv_sequence_length;
- gqa_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
- gqa_parameters.num_heads = parameters.q_num_heads;
- gqa_parameters.head_size = parameters.head_size;
- gqa_parameters.v_head_size = parameters.v_head_size;
- gqa_parameters.kv_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
- gqa_parameters.kv_num_heads = parameters.kv_num_heads;
- gqa_parameters.scale = parameters.scale;
- gqa_parameters.softcap = parameters.softcap;
- gqa_parameters.qkv_format = contribop_parameters.qkv_format;
-
- // Unset or set to default values for GQA-specific fields
- gqa_parameters.rotary_dim = 0; // New Attention op doesn't use rotary embeddings directly
- gqa_parameters.is_unidirectional = true; // GQA requires causal attention
- gqa_parameters.is_packed_qkv = false; // New Attention op has separate Q, K, V inputs
- gqa_parameters.is_subsequent_prompt = false;
- gqa_parameters.is_first_prompt = parameters.past_sequence_length == 0;
- gqa_parameters.do_rotary = false; // New Attention op doesn't use rotary embeddings
- gqa_parameters.rotary_interleaved = false;
- gqa_parameters.use_smooth_softmax = false;
- gqa_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
- gqa_parameters.past_kv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH;
- gqa_parameters.local_window_size = -1; // No local window for standard attention
- gqa_parameters.zeros_count = 0;
- gqa_parameters.zero_ptr = nullptr;
- gqa_parameters.num_splits = 1;
-
- // Construct GroupQueryAttentionData
- typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT;
- onnxruntime::contrib::cuda::GroupQueryAttentionData gqa_data;
-
- // Scratch buffers for flash/memory efficient attention
- IAllocatorUniquePtr k_buffer;
- IAllocatorUniquePtr v_buffer;
- IAllocatorUniquePtr fmha_buffer;
- IAllocatorUniquePtr unpacked_qkv_buffer;
- IAllocatorUniquePtr seq_lens_buffer;
- IAllocatorUniquePtr seqlens_k_buffer;
-
- // Present KV cache buffers - GQA kernel uses these as working buffers
- // If outputs are not provided, we allocate scratch buffers
- IAllocatorUniquePtr present_key_scratch;
- IAllocatorUniquePtr present_value_scratch;
-
- // Set input pointers
- gqa_data.query = reinterpret_cast(Q->Data());
- gqa_data.key = reinterpret_cast(K->Data());
- gqa_data.value = reinterpret_cast(V->Data());
- gqa_data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data());
- gqa_data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data());
-
- // Set output pointers
- gqa_data.output = reinterpret_cast(Y->MutableData());
-
- // GQA kernel requires present_key/present_value buffers as working storage for KV cache
- // Allocate scratch buffers if outputs are not provided
- size_t present_kv_size = static_cast(parameters.batch_size) *
- static_cast(parameters.kv_num_heads) *
- static_cast(parameters.total_sequence_length) *
- static_cast(parameters.head_size) * sizeof(CudaT);
- if (present_key != nullptr) {
- gqa_data.present_key = reinterpret_cast(present_key->MutableData());
- } else {
- present_key_scratch = GetScratchBuffer(present_kv_size, context->GetComputeStream());
- gqa_data.present_key = reinterpret_cast(present_key_scratch.get());
- }
- if (present_value != nullptr) {
- gqa_data.present_value = reinterpret_cast(present_value->MutableData());
- } else {
- present_value_scratch = GetScratchBuffer(present_kv_size, context->GetComputeStream());
- gqa_data.present_value = reinterpret_cast(present_value_scratch.get());
- }
+ // K/V new tokens: mha_fwd_kvcache expects BSNH for k_new/v_new.
+ // When !is_bsnh (4D BNSH input), transpose new tokens to BSNH.
+ const void* k_new = K->Data();
+ const void* v_new = V->Data();
+ IAllocatorUniquePtr k_bsnh_buffer;
+ IAllocatorUniquePtr v_bsnh_buffer;
+ if (!is_bsnh) {
+ size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+ parameters.kv_num_heads * parameters.head_size;
+ size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+ parameters.kv_num_heads * parameters.v_head_size;
+ k_bsnh_buffer = GetScratchBuffer(k_bytes, context->GetComputeStream());
+ v_bsnh_buffer = GetScratchBuffer(v_bytes, context->GetComputeStream());
+ ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+ parameters.batch_size, parameters.kv_sequence_length,
+ parameters.kv_num_heads, parameters.head_size,
+ K->Data(), k_bsnh_buffer.get(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+ parameters.batch_size, parameters.kv_sequence_length,
+ parameters.kv_num_heads, parameters.v_head_size,
+ V->Data(), v_bsnh_buffer.get(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ k_new = k_bsnh_buffer.get();
+ v_new = v_bsnh_buffer.get();
+ }
- // Compute past_present_share_buffer early since it's needed for flash attention path selection
- gqa_parameters.past_present_share_buffer = (gqa_data.past_key == gqa_data.present_key);
+ // mha_fwd_kvcache: present_key/value as cache (BNSH), K/V as new tokens (BSNH).
+ // The kernel appends new tokens at position seqlens_k[b] and attends to all.
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+ device_prop, cuda_stream,
+ const_cast(q_data),
+ static_cast(present_key->MutableData()),
+ static_cast(present_value->MutableData()),
+ const_cast(k_new), const_cast(v_new),
+ out_data,
+ softmax_lse_buffer.get(),
+ static_cast(seqlens_k_buffer.get()),
+ /*rotary_cos=*/nullptr, /*rotary_sin=*/nullptr,
+ /*cache_batch_idx=*/nullptr, /*leftpad_k=*/nullptr,
+ /*head_sink=*/nullptr, /*block_table=*/nullptr,
+ parameters.batch_size, parameters.q_num_heads, parameters.kv_num_heads,
+ parameters.head_size,
+ parameters.q_sequence_length, parameters.total_sequence_length,
+ parameters.kv_sequence_length, /*rotary_dim=*/0,
+ parameters.scale, parameters.softcap,
+ parameters.is_causal, is_bf16, /*use_smooth_softmax=*/false,
+ /*past_bsnh=*/false, // present cache is BNSH
+ static_cast(num_splits),
+ softmax_lse_accum_buffer.get(), out_accum_buffer.get(),
+ /*local_window_size=*/-1, /*is_rotary_interleaved=*/false,
+ /*is_packed_qkv=*/false));
- // Flash Attention buffers
- IAllocatorUniquePtr softmax_lse_buffer;
- IAllocatorUniquePtr softmax_lse_accum_buffer;
- IAllocatorUniquePtr out_accum_buffer;
+ present_kv_already_populated = true;
+ }
+ // --- Path 3: Prompt flash (no past, no mask) ---
+ // Note: prompt with bool mask is handled by MEA (flash_eligible excludes it).
+ else {
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+ device_prop, cuda_stream,
+ const_cast(q_data),
+ const_cast(static_cast(K->Data())),
+ const_cast(static_cast(V->Data())),
+ out_data,
+ softmax_lse_buffer.get(),
+ parameters.batch_size, parameters.q_num_heads, parameters.kv_num_heads,
+ parameters.head_size,
+ parameters.q_sequence_length, parameters.kv_sequence_length,
+ parameters.scale, parameters.softcap,
+ parameters.is_causal, is_bf16, /*use_smooth_softmax=*/false,
+ static_cast(num_splits),
+ softmax_lse_accum_buffer.get(), out_accum_buffer.get(),
+ is_bsnh));
+ }
- // Check Flash Attention support
-#if USE_FLASH_ATTENTION
- bool use_flash_attention = onnxruntime::flash::is_supported(device_prop,
- gqa_parameters.head_size,
- gqa_parameters.num_heads,
- gqa_parameters.kv_num_heads);
-
- gqa_data.use_flash_attention = use_flash_attention;
- gqa_data.use_flash_attention_fast_decode = use_flash_attention &&
- !gqa_parameters.is_first_prompt &&
- gqa_parameters.past_present_share_buffer;
-
- if (use_flash_attention) {
- // Allocate Flash specific buffers (Softmax LSE, Accum)
- size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(
- gqa_parameters.sequence_length, gqa_parameters.batch_size, gqa_parameters.num_heads);
-
- int num_heads_for_split = gqa_data.use_flash_attention_fast_decode
- ? gqa_parameters.kv_num_heads
- : gqa_parameters.num_heads;
- auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] =
- onnxruntime::flash::get_num_splits_and_buffer_sizes(
- gqa_parameters.batch_size, gqa_parameters.sequence_length,
- gqa_parameters.total_sequence_length, num_heads_for_split,
- gqa_parameters.head_size, device_prop.multiProcessorCount);
-
- gqa_parameters.num_splits = static_cast(num_splits);
-
- if (gqa_data.use_flash_attention_fast_decode && num_splits > 1) {
- // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel.
- // However, the LSE and Accum buffers must store results for ALL num_heads.
- softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(
- num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length);
- auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; };
- out_accum_bytes = onnxruntime::flash::get_out_accum_size(
- num_splits, gqa_parameters.batch_size, gqa_parameters.num_heads, gqa_parameters.sequence_length,
- round_multiple(gqa_parameters.head_size, 32));
- }
-
- softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream());
- softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream());
- out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream());
-
- gqa_data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get());
- gqa_data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get());
- gqa_data.out_accum = reinterpret_cast(out_accum_buffer.get());
- } else {
- gqa_data.softmax_lse = nullptr;
- gqa_data.softmax_lse_accum = nullptr;
- gqa_data.out_accum = nullptr;
- }
+ // --- Transpose output BSNH → BNSH if input was 4D (BNSH) ---
+ if (!is_bsnh && out_bsnh_buffer != nullptr) {
+ ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+ parameters.batch_size, parameters.q_sequence_length,
+ parameters.q_num_heads, parameters.v_head_size,
+ out_bsnh_buffer.get(), Y->MutableData(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ }
+
+ // --- Populate present_key/value (BNSH) from K/V (BSNH) ---
+ // Skip for decode path where mha_fwd_kvcache already populated present buffers.
+ if (!present_kv_already_populated) {
+ if (present_key != nullptr && is_bsnh) {
+ ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+ parameters.batch_size, parameters.kv_sequence_length,
+ parameters.kv_num_heads, parameters.head_size,
+ K->Data(), present_key->MutableData(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ } else if (present_key != nullptr && !is_bsnh) {
+ // 4D BNSH prompt: K is already BNSH, just D2D copy to present
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+ present_key->MutableData(), K->Data(),
+ K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+ }
+ if (present_value != nullptr && is_bsnh) {
+ ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+ parameters.batch_size, parameters.kv_sequence_length,
+ parameters.kv_num_heads, parameters.v_head_size,
+ V->Data(), present_value->MutableData(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ } else if (present_value != nullptr && !is_bsnh) {
+ // 4D BNSH prompt: V is already BNSH, just D2D copy to present
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+ present_value->MutableData(), V->Data(),
+ V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+ }
+ }
+
+ return Status::OK();
#else
- gqa_data.use_flash_attention = false;
- gqa_data.use_flash_attention_fast_decode = false;
- gqa_data.softmax_lse = nullptr;
- gqa_data.softmax_lse_accum = nullptr;
- gqa_data.out_accum = nullptr;
+ ORT_UNUSED_PARAMETER(context);
+ ORT_UNUSED_PARAMETER(Q);
+ ORT_UNUSED_PARAMETER(K);
+ ORT_UNUSED_PARAMETER(V);
+ ORT_UNUSED_PARAMETER(attn_mask);
+ ORT_UNUSED_PARAMETER(past_key);
+ ORT_UNUSED_PARAMETER(past_value);
+ ORT_UNUSED_PARAMETER(nonpad_kv_seqlen);
+ ORT_UNUSED_PARAMETER(Y);
+ ORT_UNUSED_PARAMETER(present_key);
+ ORT_UNUSED_PARAMETER(present_value);
+ ORT_UNUSED_PARAMETER(parameters);
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "Flash attention is not available in this build.");
#endif
+}
- // Check Memory Efficient Attention support (fallback if flash attention not available)
+// ============================================================================
+// RunMemoryEfficientAttention: Direct memory-efficient attention kernel call
+// ============================================================================
+//
+// Memory Efficient Attention (cutlass FMHA) dispatch paths:
+// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
+// Path 2: no past, with mask (prompt) -> standard MEA with additive bias
+// Path 3: no past, no mask (prompt) -> standard MEA
+// Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype,
+// head_size <= 1024), plus: no output_qk, no past_key (decode excluded),
+// bias stride alignment.
+// Note: softcap and softmax_precision are early-rejected before the cascade.
+//
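+// Illustrative routing examples (assumed workloads): an opset-24 graph passing nonpad_kv_seqlen
+// takes Path 1 (custom right padding); a prompt with a bool or float mask and no past takes
+// Path 2 after the mask is converted to an additive bias; a prompt with no mask takes Path 3.
+//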
+template <typename T>
+Status Attention<T>::RunMemoryEfficientAttention(
+ OpKernelContext* context,
+ const Tensor* Q, const Tensor* K, const Tensor* V,
+ const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
+ const Tensor* nonpad_kv_seqlen,
+ Tensor* Y, Tensor* present_key, Tensor* present_value,
+ const attention_helper::AttentionParameters& parameters) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
- if (!gqa_data.use_flash_attention) {
- int sm = (device_prop.major * 10) + device_prop.minor;
- bool use_memory_efficient_attention =
- onnxruntime::contrib::cuda::has_memory_efficient_attention(
- sm, std::is_same::value, std::is_same::value,
- gqa_parameters.head_size, gqa_parameters.head_size);
- gqa_data.use_memory_efficient_attention = use_memory_efficient_attention;
-
- // KV buffer for head expansion (when num_heads != kv_num_heads)
- size_t kv_buffer_bytes = (use_memory_efficient_attention &&
- (gqa_parameters.num_heads != gqa_parameters.kv_num_heads))
- ? (sizeof(T) * gqa_parameters.batch_size * gqa_parameters.num_heads *
- gqa_parameters.seqlen_present_kv_cache * gqa_parameters.head_size)
- : 0;
- // FMHA workspace
- size_t fmha_buffer_bytes =
- (use_memory_efficient_attention &&
- onnxruntime::contrib::cuda::MemoryEfficientAttentionParams::need_workspace(
- gqa_parameters.head_size, sizeof(T) == sizeof(float)))
- ? (sizeof(float) * gqa_parameters.batch_size * gqa_parameters.sequence_length *
- gqa_parameters.num_heads * gqa_parameters.head_size)
- : 0;
-
- k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream());
- v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream());
- fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream());
-
- gqa_data.k = reinterpret_cast(k_buffer.get());
- gqa_data.v = reinterpret_cast(v_buffer.get());
- gqa_data.fmha_buffer = reinterpret_cast(fmha_buffer.get());
- } else {
- gqa_data.use_memory_efficient_attention = false;
- gqa_data.k = nullptr;
- gqa_data.v = nullptr;
- gqa_data.fmha_buffer = nullptr;
- }
-#else
- gqa_data.use_memory_efficient_attention = false;
- gqa_data.k = nullptr;
- gqa_data.v = nullptr;
- gqa_data.fmha_buffer = nullptr;
-#endif
+ ORT_UNUSED_PARAMETER(past_key);
+ ORT_UNUSED_PARAMETER(past_value);
+ auto& device_prop = GetDeviceProp();
+ auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
+ const bool is_bsnh = parameters.transpose_output;
+ const int sm = device_prop.major * 10 + device_prop.minor;
- // Centralized scratch buffer allocation using GQABufferRequirements
- auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute(
- gqa_parameters,
- false, // use_xqa
- gqa_data.use_flash_attention,
- gqa_data.use_flash_attention_fast_decode,
- gqa_data.use_memory_efficient_attention);
-
- if (buffer_req.qkv_buffer_bytes > 0) {
- unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
- gqa_data.qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get());
- } else {
- gqa_data.qkv_buffer = nullptr;
- }
+ // Q/K/V pointers — MEA expects BSNH format for Q
+ const void* q_data = Q->Data();
+ const void* k_data = K->Data();
+ const void* v_data = V->Data();
- // Allocate GPU buffer for seqlens_k (total_sequence_length - 1) for GQA compatibility
- // The GQA kernel expects sequence length information for flash/memory efficient attention
- seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream());
- auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle());
-
- // GQA only supports masking, not additive bias.
- // For bool mask, we need to convert it to sequence lengths on GPU.
- // Note: The GQA path interprets 2D bool masks as (batch_size, total_seq_len) since it converts
- // masks to seqlens_k directly (bypassing ONNX right-aligned broadcasting). This differs from
- // the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
- // 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
- if (attn_mask != nullptr && attn_mask->IsDataType()) {
- // Allocate validation result buffer on GPU
- // Get mask dimensions for broadcasting
- // attn_mask can be 2D, 3D, or 4D and broadcasts to (batch_size, num_heads, q_seq_len, total_seq_len)
- const auto& mask_shape = attn_mask->Shape();
- int mask_dims = static_cast(mask_shape.NumDimensions());
- int64_t mask_dim0 = 0, mask_dim1 = 0, mask_dim2 = 0;
-
- if (mask_dims == 2) {
- // Shape: (batch_size or 1, total_seq_len)
- mask_dim0 = mask_shape[0];
- mask_dim1 = 0;
- mask_dim2 = 0;
- } else if (mask_dims == 3) {
- // Shape: (num_heads or 1, q_seq_len, total_seq_len)
- mask_dim0 = mask_shape[0];
- mask_dim1 = mask_shape[1];
- mask_dim2 = 0;
- } else if (mask_dims == 4) {
- // Shape: (batch_size or 1, num_heads or 1, q_seq_len, total_seq_len)
- mask_dim0 = mask_shape[0];
- mask_dim1 = mask_shape[1];
- mask_dim2 = mask_shape[2];
- } else {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Boolean attn_mask must be 2D, 3D, or 4D. Got ", mask_dims, "D.");
- }
-
- // Launch CUDA kernel to convert mask to seqlens_k and validate
- // Mask validity (right-padding, contiguous) is checked asynchronously via CUDA_KERNEL_ASSERT.
- ORT_RETURN_IF_ERROR(LaunchConvertMaskToSeqlensK(
- attn_mask->Data(),
- seqlens_k_buffer.get(),
- parameters.batch_size,
- parameters.total_sequence_length,
- mask_dims,
- mask_dim0,
- mask_dim1,
- mask_dim2,
- cuda_stream,
- device_prop.maxThreadsPerBlock));
- } else if (attn_mask != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
- "Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA).");
- } else {
- // No mask provided - use full sequence length for all batches
- // seqlens_k is total_sequence_length - 1 for historical reasons (matching GroupQueryAttention convention)
- // Fill on GPU using cudaMemset-like approach or a simple kernel
- std::vector seqlens_k_host(parameters.batch_size, parameters.total_sequence_length - 1);
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(seqlens_k_buffer.get(), seqlens_k_host.data(),
- sizeof(int) * parameters.batch_size,
- cudaMemcpyHostToDevice, cuda_stream));
- }
+ // --- Transpose Q from BNSH to BSNH if 4D input ---
+ IAllocatorUniquePtr q_bsnh_buffer;
+ if (!is_bsnh) {
+ size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length *
+ parameters.q_num_heads * parameters.head_size;
+ q_bsnh_buffer = GetScratchBuffer<void>(q_bytes, context->GetComputeStream());
+ ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+ parameters.batch_size, parameters.q_sequence_length,
+ parameters.q_num_heads, parameters.head_size,
+ Q->Data<T>(), q_bsnh_buffer.get(),
+ cuda_stream, device_prop.maxThreadsPerBlock));
+ q_data = q_bsnh_buffer.get();
+ }
+
+ // MEA output is BSNH. If Y expects BNSH, write to scratch then transpose.
+ void* out_data = Y->MutableData<T>();
+ IAllocatorUniquePtr<void> out_bsnh_buffer;
+ if (!is_bsnh) {
+ size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length *
+ parameters.q_num_heads * parameters.v_head_size;
+ out_bsnh_buffer = GetScratchBuffer<void>(out_bytes, context->GetComputeStream());
+ out_data = out_bsnh_buffer.get();
+ }
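+ // out_data now points either at Y directly (BSNH case) or at the BSNH scratch buffer,
+ // which is expected to be transposed back into Y's BNSH layout once the kernel has run.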
- // Process seqlens_k to compute past_seq_lens, total_seq_lens, and padded_seq_lens
- // This is always needed for flash/memory efficient attention
- seq_lens_buffer = GetScratchBuffer<int>(3 * parameters.batch_size, context->GetComputeStream());
- gqa_data.past_seq_lens = seq_lens_buffer.get();
- gqa_data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size;
- gqa_data.padded_seq_lens = gqa_data.total_seq_lens + parameters.batch_size;
-
- ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchGetSequenceLengths(
- seqlens_k_buffer.get(),
- gqa_data.past_seq_lens,
- gqa_data.total_seq_lens,
- gqa_data.padded_seq_lens,
- parameters.batch_size,
- parameters.q_sequence_length,
- gqa_parameters.is_first_prompt,
+ // GQA head expansion: MEA requires matching num_heads for Q/K/V.
+ // When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup.
+ const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads;
+ IAllocatorUniquePtr<void> k_expand_buffer;
+ IAllocatorUniquePtr<void> v_expand_buffer;
+
+ if (is_gqa) {
+ // GQA with MEA only works for fp16/bf16: LaunchUngroup has no fp32 template instantiation
+ // in group_query_attention_impl.cu.
+ // Guard with if constexpr so the float branch never instantiates LaunchUngroup.
+ if constexpr (std::is_same_v<T, float>) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "GQA with Memory Efficient Attention requires fp16 or bf16, not fp32.");
+ } else {
+ ORT_ENFORCE(parameters.head_size == parameters.v_head_size,
+ "GQA with MEA requires head_size == v_head_size for LaunchUngroup.");
+ ORT_ENFORCE(parameters.head_size % 4 == 0,
+ "GQA with MEA requires head_size divisible by 4 for LaunchUngroup (float2 access).");
+ const size_t expanded_kv_elements = static_cast<size_t>(parameters.batch_size) *
+ static_cast<size_t>(parameters.total_sequence_length) *
+ static_cast<size_t>(parameters.q_num_heads) *
+ static_cast<size_t>(parameters.head_size);
+ k_expand_buffer = GetScratchBuffer<void>(expanded_kv_elements * sizeof(T), context->GetComputeStream());
+ v_expand_buffer = GetScratchBuffer<void>(expanded_kv_elements * sizeof(T), context->GetComputeStream());
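+ // Ungrouping repeats each of the kv_num_heads K/V heads (q_num_heads / kv_num_heads) times so
+ // every query head has a matching K/V head; the expanded buffers grow by that same factor.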
+
+ onnxruntime::contrib::GroupQueryAttentionParameters ungroup_params = {};
+ ungroup_params.batch_size = parameters.batch_size;
+ ungroup_params.num_heads = parameters.q_num_heads;
+ ungroup_params.kv_num_heads = parameters.kv_num_heads;
+ ungroup_params.head_size = parameters.head_size;
+
+ using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
+ ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchUngroup(
+ ungroup_params,
+ reinterpret_cast<NativeCudaT*>(k_expand_buffer.get()),
+ reinterpret_cast<NativeCudaT*>(v_expand_buffer.get()),
+ reinterpret_cast<const NativeCudaT*>(k_data),
+ reinterpret_cast<const NativeCudaT*>(v_data),
+ parameters.total_sequence_length,
+ parameters.total_sequence_length,
+ is_bsnh,
cuda_stream,
device_prop.maxThreadsPerBlock));
- // Set GQA-specific fields
- gqa_data.cos_cache = nullptr; // No rotary embeddings
- gqa_data.sin_cache = nullptr;
- gqa_data.head_sink = nullptr;
- gqa_data.position_ids = nullptr;
-
- // Call GQA kernel (with flash or memory efficient attention)
- cublasHandle_t cublas = GetCublasHandle(context);
-
- return onnxruntime::contrib::cuda::QkvToContext<CudaT>(
- device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data);
- } // else (non-float GQA path)
- } else { // MHA path (kv_num_heads == q_num_heads)
- typedef typename ToCudaType<T>::MappedType CudaT;
- contribop_parameters.batch_size = parameters.batch_size;
- contribop_parameters.sequence_length = parameters.q_sequence_length;
- contribop_parameters.kv_sequence_length = parameters.kv_sequence_length;
- contribop_parameters.past_sequence_length = parameters.past_sequence_length;
- contribop_parameters.total_sequence_length = parameters.total_sequence_length;
- // max_sequence_length: For non-buffer-sharing case, this equals total_sequence_length (the present KV cache size)
- contribop_parameters.max_sequence_length = parameters.total_sequence_length;
- contribop_parameters.input_hidden_size = 0; // Not applicable - new Attention op takes pre-projected Q/K/V
- contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size;
- contribop_parameters.head_size = parameters.head_size;
- contribop_parameters.v_head_size = parameters.v_head_size;
- contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size;
- contribop_parameters.num_heads = parameters.q_num_heads;
- contribop_parameters.rotary_dim = 0;
- contribop_parameters.num_splits = 1;
- contribop_parameters.beam_width = 1;
- contribop_parameters.is_unidirectional = parameters.is_causal;
- contribop_parameters.past_present_share_buffer = false; // New Attention op doesn't share buffer
- contribop_parameters.is_packed_qkv = false;
- contribop_parameters.do_rotary = false;
-
- // The new Attention op uses attn_mask as attention_bias (additive bias), not as key_padding_mask
- // So mask_type should always be MASK_NONE since we don't have a separate padding mask input
- contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;
-
- // Determine broadcast flags for attention_bias (if it exists)
- // The MHA path uses attn_mask as attention_bias (additive bias added before softmax).
- // Bool masks are element-wise converted to additive bias (true → 0.0, false → -inf),
- // preserving the original shape, so the same broadcasting rules apply to both types.
- //
- // ONNX broadcasting is right-aligned to target shape (batch, heads, q_seq, total_seq):
- // 2D [A, B] → [1, 1, A, B] : A = q_seq_len, B = total_seq_len
- // 3D [A, B, C] → [1, A, B, C] : A = heads, B = q_seq_len, C = total_seq_len
- // 4D [A, B, C, D] → [A, B, C, D] : A = batch, B = heads, C = q_seq_len, D = total_seq_len
- //
- // Note: A 2D mask cannot represent per-batch padding because the batch dimension is broadcast.
- // For per-batch boolean padding masks, use 4D shape (batch, 1, 1, total_seq_len).
+ k_data = k_expand_buffer.get();
+ v_data = v_expand_buffer.get();
+ }
+ }
+
+ // Note: MEA with past_key/value is handled by the unfused fallback.
+ // The cascade in ComputeInternal ensures past_key == nullptr when we reach here.
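+ // With past_key == nullptr there is no KV cache to concatenate here; K and V already cover the
+ // full total_sequence_length.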
+
+ // Handle attention mask → attention_bias conversion
+ IAllocatorUniquePtr<T> converted_mask_buffer;
+ const void* attn_bias_data = nullptr;
+ bool broadcast_bias_dim_0 = false;
+ bool broadcast_bias_dim_1 = false;
+
+ if (nonpad_kv_seqlen != nullptr) {
+ // Convert nonpad_kv_seqlen to seqlens_k for custom right padding.
+ // MEA expects the actual valid token count (not count - 1), so use the FlashSeqlensK variant.
+ auto seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
+ ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK(
+ nonpad_kv_seqlen->Data<int64_t>(),
+ seqlens_k_buffer.get(),
+ parameters.batch_size,
+ parameters.total_sequence_length,
+ cuda_stream,
+ device_prop.maxThreadsPerBlock));
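+ // Example: a row with 5 valid KV tokens is stored as 4 under the GQA (count - 1) convention,
+ // but as 5 in the flash/MEA-style seqlens_k produced here.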
+
+ // When attn_mask is also provided, convert it to additive attn_bias so MEA
+ // applies both custom right padding (seqlens_k) and the attention mask (attn_bias).
if (attn_mask != nullptr) {
- size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
- auto attn_mask_dims = attn_mask->Shape().GetDims();
- // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
- // For 3D mask (heads_or_1, q_seq_len, total_seq_len): batch always broadcasts, heads broadcasts if dim[0]==1
- // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1
-
- if (attn_mask_dims_size == 2) {
- // 2D mask: both dimensions need broadcasting
- contribop_parameters.broadcast_attn_bias_dim_0 = true;
- contribop_parameters.broadcast_attn_bias_dim_1 = true;
- } else if (attn_mask_dims_size == 3) {
- // 3D mask [A, q_seq_len, total_seq_len]: right-aligned to [_, A, q_seq, total_seq]
- // A maps to heads dimension (validated to be 1 or q_num_heads by attention_helper.h)
- // Batch dimension is missing, so always broadcasts
- contribop_parameters.broadcast_attn_bias_dim_0 = true;
- contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1;
- } else {
- // 4D mask: check both dim 0 and dim 1 explicitly
- contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
- contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1;
- }
- } else {
- contribop_parameters.broadcast_attn_bias_dim_0 = false;
- contribop_parameters.broadcast_attn_bias_dim_1 = false;
+ ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream,
+ device_prop.maxThreadsPerBlock,
+ converted_mask_buffer, attn_bias_data,
+ broadcast_bias_dim_0, broadcast_bias_dim_1));
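+ // Boolean masks are expected to convert element-wise to additive bias (true -> 0.0,
+ // false -> -inf), so masked positions contribute nothing after softmax.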
}
- contribop_parameters.mask_filter_value = static_cast