132 changes: 58 additions & 74 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -134,7 +134,7 @@ static Status TransposeBSNHtoBNSH(int batch_size, int sequence_length,

// ============================================================================
// ConvertAttnMaskToBias: shared helper for mask→additive bias conversion.
// Used by both Flash (nonpad+mask) and MEA paths to avoid code duplication.
// Used by the MEA path to convert masks before the CUTLASS kernel call.
// Converts bool masks to additive bias (true→0, false→mask_filter_value),
// passes float masks through directly, and sets broadcast flags from mask shape.
// ============================================================================
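For orientation, the core of that conversion is a simple element-wise kernel. The sketch below is illustrative only (the real conversion kernel and its launcher live in the accompanying .cu file; the name and signature here are assumptions), but it shows the true→0 / false→mask_filter_value mapping:

template <typename T>
__global__ void BoolMaskToAdditiveBiasSketch(const bool* mask, T* bias,
                                             int64_t num_elements, float mask_filter_value) {
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (idx < num_elements) {
    // true -> 0 (position is attended), false -> large negative filter value
    bias[idx] = mask[idx] ? static_cast<T>(0.0f) : static_cast<T>(mask_filter_value);
  }
}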
@@ -186,15 +186,12 @@ Status Attention<T>::ConvertAttnMaskToBias(
// Flash Attention dispatch paths:
// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache
// Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache
// - Supports: bool mask -> seqlens_k, no mask -> fill past_seq_len
// - No mask support (attn_mask rejected at eligibility)
// - 4D BNSH: transposes Q/K/V to BSNH before kernel
// Path 3: no past, no mask (prompt) -> mha_fwd
// Eligibility: fp16/bf16, head_size==v_head_size, no output_qk,
// (no mask OR bool mask + past OR nonpad_kv_seqlen without mask)
// Eligibility: fp16/bf16, head_size==v_head_size, no output_qk, attn_mask==nullptr
// Note: softcap is passed to the Flash kernel natively. softmax_precision is
// inherently satisfied (Flash accumulates softmax in FP32).
// Note: nonpad_kv_seqlen + attn_mask routes to MEA/unfused, not Flash
// (Flash has no bias parameter for this combination).
//
// PERFORMANCE NOTE: ONNX Attention's internal-cache decode path (past_key/past_value)
// is ~15-30% slower than contrib GQA's decode path for grouped-query attention workloads.
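A condensed view of the dispatch described above (illustrative only; the real eligibility and routing logic in ComputeInternal/RunFlashAttention checks additional conditions such as dtype, SM version, and head sizes):

enum class FlashPath { kExternalCacheKvSeqlen,  // Path 1 -> mha_fwd_kvcache
                       kInternalCacheDecode,    // Path 2 -> mha_fwd_kvcache
                       kPrompt };               // Path 3 -> mha_fwd

inline FlashPath SelectFlashPath(bool has_nonpad_kv_seqlen, bool has_past_kv) {
  if (has_nonpad_kv_seqlen) return FlashPath::kExternalCacheKvSeqlen;
  if (has_past_kv) return FlashPath::kInternalCacheDecode;
  return FlashPath::kPrompt;
}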
@@ -227,7 +224,7 @@ template <typename T>
Status Attention<T>::RunFlashAttention(
OpKernelContext* context,
const Tensor* Q, const Tensor* K, const Tensor* V,
const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
const Tensor* past_key, const Tensor* past_value,
const Tensor* nonpad_kv_seqlen,
Tensor* Y, Tensor* present_key, Tensor* present_value,
const attention_helper::AttentionParameters& parameters) const {
@@ -294,6 +291,8 @@ Status Attention<T>::RunFlashAttention(
"(past_sequence_length must be 0, got ",
parameters.past_sequence_length, ").");

// seqlens_k_buffer lifetime: allocated via BFC arena, remains valid for all kernel
// launches on the same CUDA stream until the IAllocatorUniquePtr goes out of scope.
auto seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, GetComputeStream(context));
ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK(
nonpad_kv_seqlen->Data<int64_t>(),
@@ -348,25 +347,11 @@ Status Attention<T>::RunFlashAttention(

// Step 1: Compute per-batch past sequence lengths for the concat kernel.
// The concat kernel needs past_seq_lens to know where past data ends and new begins.
// attn_mask is always nullptr here (Flash rejects attn_mask), so use uniform past_seq.
auto past_seqlens_buffer = GetScratchBuffer<int>(parameters.batch_size, GetComputeStream(context));
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
size_t mask_dims = attn_mask->Shape().NumDimensions();
auto dims = attn_mask->Shape().GetDims();
int64_t mask_dim0 = dims[0];
int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0;
int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0;
// Offset -kv_seq: mask encodes total valid count; subtract to get past-only count.
int seqlen_offset = -parameters.kv_sequence_length;
ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK(
attn_mask->Data<bool>(), past_seqlens_buffer.get(),
parameters.batch_size, parameters.total_sequence_length,
static_cast<int>(mask_dims), mask_dim0, mask_dim1, mask_dim2,
cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset));
} else {
ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length,
parameters.batch_size, cuda_stream,
device_prop.maxThreadsPerBlock));
}
ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length,
parameters.batch_size, cuda_stream,
device_prop.maxThreadsPerBlock));
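For context, the uniform fill amounts to a trivial per-element kernel along these lines (sketch under the assumption that LaunchFillInt32 wraps something equivalent; the real launcher's signature may differ):

__global__ void FillInt32Sketch(int* dst, int value, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = value;
}
// Host side: one int per batch, every entry set to past_sequence_length, e.g.
// FillInt32Sketch<<<(batch_size + 255) / 256, 256, 0, cuda_stream>>>(
//     past_seqlens_buffer.get(), parameters.past_sequence_length, parameters.batch_size);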

// Step 2: Transpose K/V to BSNH if input is 4D BNSH (concat kernel reads new as BSNH).
const T* k_new_bsnh = K->Data<T>();
Expand Down Expand Up @@ -399,11 +384,7 @@ Status Attention<T>::RunFlashAttention(
// into present buffer at [past_seq, past_seq + kv_seq), all in BNSH.
// Note: is_bsnh=false means past/present cache layout is BNSH. New tokens
// (k_new_bsnh/v_new_bsnh) are always read as BSNH by the kernel (hardcoded strides).
// NOTE: When bool masks produce variable per-batch past_seq_lens, positions in the range
// [past_seq_lens[b] + kv_sequence_length, total_sequence_length) for each batch b are left
// uninitialized by the concat kernel. Flash Attention reads only up to seqlens_k[b] positions
// per batch, so these values are never accessed. In the no-mask case (uniform past_seq_lens),
// every position in the present buffer is written.
// past_seqlens is uniform (no mask) so every position in the present buffer is written.
ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV<NativeCudaT>(
parameters.batch_size,
parameters.kv_num_heads,
@@ -427,27 +408,13 @@
// Step 4: Compute total seqlens for mha_fwd_kvcache.
// With k_new=nullptr, the kernel treats seqlens_k as the total valid token count
// (not pre-append count), so we need past + new.
// attn_mask is always nullptr here (Flash rejects attn_mask), so use uniform seqlens.
auto seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, GetComputeStream(context));
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
size_t mask_dims = attn_mask->Shape().NumDimensions();
auto dims = attn_mask->Shape().GetDims();
int64_t mask_dim0 = dims[0];
int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0;
int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0;
// Offset 0: mask encodes total valid count, which is exactly what we need.
int seqlen_offset = 0;
ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK(
attn_mask->Data<bool>(), seqlens_k_buffer.get(),
parameters.batch_size, parameters.total_sequence_length,
static_cast<int>(mask_dims), mask_dim0, mask_dim1, mask_dim2,
cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset));
} else {
ORT_RETURN_IF_ERROR(LaunchFillInt32(
seqlens_k_buffer.get(),
parameters.past_sequence_length + parameters.kv_sequence_length,
parameters.batch_size, cuda_stream,
device_prop.maxThreadsPerBlock));
}
ORT_RETURN_IF_ERROR(LaunchFillInt32(
seqlens_k_buffer.get(),
parameters.past_sequence_length + parameters.kv_sequence_length,
parameters.batch_size, cuda_stream,
device_prop.maxThreadsPerBlock));
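Concretely, for a single-token decode with past_sequence_length = 128 and kv_sequence_length = 1, Step 1 fills past_seqlens with 128 (where past data ends for the concat kernel) while this step fills seqlens_k with 129, because mha_fwd_kvcache with k_new = nullptr reads seqlens_k as the total valid token count rather than the pre-append count.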

// Step 5: Flash attention on pre-populated cache.
// k_new=nullptr tells mha_fwd_kvcache to skip its internal Append_KV — the cache
@@ -542,7 +509,6 @@ Status Attention<T>::RunFlashAttention(
ORT_UNUSED_PARAMETER(Q);
ORT_UNUSED_PARAMETER(K);
ORT_UNUSED_PARAMETER(V);
ORT_UNUSED_PARAMETER(attn_mask);
ORT_UNUSED_PARAMETER(past_key);
ORT_UNUSED_PARAMETER(past_value);
ORT_UNUSED_PARAMETER(nonpad_kv_seqlen);
Expand Down Expand Up @@ -730,6 +696,22 @@ Status Attention<T>::RunMemoryEfficientAttention(
p.workspace = nullptr;
}
onnxruntime::contrib::cuda::run_memory_efficient_attention(p);

// On the MEA (CUTLASS) path (used for both MHA and GQA when nonpad_kv_seqlen is provided),
// zero the output of fully-masked batches so they return zeros, matching Flash behavior.
// Otherwise the CUTLASS epilogue computes 1/s_prime, and s_prime == 0 when seqlens_k == 0, producing NaN.
{
using CudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
int64_t elements_per_batch = static_cast<int64_t>(parameters.q_sequence_length) *
parameters.q_num_heads * parameters.v_head_size;
ORT_RETURN_IF_ERROR(LaunchZeroOutputForFullyMaskedBatches<CudaT>(
reinterpret_cast<CudaT*>(out_data),
seqlens_k_buffer.get(),
parameters.batch_size,
elements_per_batch,
cuda_stream,
device_prop.maxThreadsPerBlock));
}
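A minimal sketch of what such a zeroing kernel could look like (illustrative; the real LaunchZeroOutputForFullyMaskedBatches lives in the accompanying .cu file, and this name/signature is an assumption):

template <typename T>
__global__ void ZeroFullyMaskedBatchesSketch(T* out, const int* seqlens_k,
                                             int batch_size, int64_t elements_per_batch) {
  int b = blockIdx.y;  // grid.y spans batches
  if (b >= batch_size || seqlens_k[b] != 0) return;  // only touch fully-masked batches
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < elements_per_batch) {
    out[static_cast<int64_t>(b) * elements_per_batch + i] = static_cast<T>(0.0f);
  }
}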
}
// Standard MEA path: float attention bias, bool mask (converted to bias), or no mask.
// Bool masks are converted to additive attention bias (true→0, false→mask_filter_value)
@@ -858,6 +840,8 @@ Status Attention<T>::RunUnfusedAttention(
Tensor* output_qk,
const attention_helper::AttentionParameters& parameters) const {
using CudaT = typename ToCudaType<T>::MappedType;
// OrtToCudaType maps BFloat16 → __nv_bfloat16 (native HW type), matching kernel instantiations.
using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
auto& device_prop = GetDeviceProp();
auto cuda_stream = Stream(context);
auto ort_stream = GetOrtStream(context);
@@ -938,7 +922,6 @@ Status Attention<T>::RunUnfusedAttention(
IAllocatorUniquePtr<void> mask_bias_buffer; // temp buffer for mask→bias when composing
if (nonpad_kv_seqlen != nullptr) {
// Convert nonpad_kv_seqlen to additive attention bias: [B, q_seq, total_seq]
using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
int64_t bias_elements = static_cast<int64_t>(parameters.batch_size) *
parameters.q_sequence_length *
parameters.total_sequence_length;
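As an illustration of the conversion this step performs (sketch only; the actual launcher and how it obtains the filter value may differ), each key position beyond a batch's valid length receives the filter value and every earlier position receives zero:

// bias has shape [B, q_seq, total_seq]; nonpad_kv_seqlen holds the valid KV count per batch.
template <typename T>
__global__ void NonpadKvSeqlenToBiasSketch(const int64_t* nonpad_kv_seqlen, T* bias,
                                           int q_seq, int total_seq, float mask_filter_value) {
  int b = blockIdx.z;
  int q = blockIdx.y;
  int k = blockIdx.x * blockDim.x + threadIdx.x;
  if (k < total_seq) {
    T v = (k < nonpad_kv_seqlen[b]) ? static_cast<T>(0.0f) : static_cast<T>(mask_filter_value);
    bias[(static_cast<int64_t>(b) * q_seq + q) * total_seq + k] = v;
  }
}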
@@ -1004,7 +987,6 @@ Status Attention<T>::RunUnfusedAttention(
contribop_parameters.broadcast_attn_bias_dim_1 = true;
} else if (attn_mask != nullptr) {
if (attn_mask->IsDataType<bool>()) {
using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
int64_t num_elements = attn_mask->Shape().Size();
converted_mask_buffer = GetScratchBuffer<void>(num_elements * sizeof(NativeCudaT), GetComputeStream(context));
ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias<NativeCudaT>(
@@ -1049,6 +1031,9 @@ Status Attention<T>::RunUnfusedAttention(
cublasHandle_t cublas = GetCublasHandle(context);
cudnnHandle_t cudnn = GetCudnnHandle(context);

// Note: unfused attention produces valid finite output (mean-of-V via uniform softmax)
// for fully-masked batches, so ZeroOutput is not needed here. Only MEA requires
// ZeroOutput to prevent NaN from the CUTLASS epilogue's 1/s_prime division.
return onnxruntime::contrib::cuda::QkvToContext<CudaT, CudaT>(
device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data);
}
@@ -1134,20 +1119,17 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.q_num_heads, parameters.kv_num_heads) &&
parameters.head_size == parameters.v_head_size &&
!has_output_qk &&
// Bool masks without past_key (prompt) can't use flash because mha_fwd_kvcache's
// causal semantics are decode-oriented (window offset by seqlens_k). For causal
// prompt with padding, MEA handles it correctly via attention bias conversion.
// Flash handles: no mask, decode with past (±mask), nonpad_kv_seqlen.
// Note: contrib MHA similarly excludes flash when attention_bias is present
// (no mask support in mha_fwd). Float masks and bool prompt masks route to MEA
// which supports additive bias natively.
(attn_mask == nullptr || (attn_mask->IsDataType<bool>() && past_key != nullptr)) &&
// Flash cannot handle nonpad_kv_seqlen + attn_mask simultaneously (no bias parameter
// in mha_fwd/mha_fwd_kvcache when seqlens_k is used). Route to MEA instead.
!(nonpad_kv_seqlen != nullptr && attn_mask != nullptr);
// Flash does not support attention masks (no bias parameter in mha_fwd/mha_fwd_kvcache).
// Bool attn_mask + past_key is rejected because Flash uses paged KV cache semantics
// that produce spec-divergent present_kv layout for partial masks (e.g. [T,T,T,F]).
// Unfused handles bool+past_key spec-correctly via standard ConcatPastToPresent.
// TODO(titaiwang): GQA + bool attn_mask + past_key currently has no runner (Flash
// rejected here, unfused doesn't support GQA, MEA blocked by past_key != nullptr).
// Once PR #27851 merges (MEA supports past_key), this gap will be covered.
attn_mask == nullptr;

if (flash_eligible) {
return RunFlashAttention(context, Q, K, V, attn_mask, past_key, past_value,
return RunFlashAttention(context, Q, K, V, past_key, past_value,
nonpad_kv_seqlen, Y, present_key, present_value, parameters);
}
}
@@ -1171,13 +1153,14 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// total_sequence_length. Skip MEA if this stride can't satisfy the kernel's
// minimum alignment requirement.
if (mea_eligible && attn_mask != nullptr) {
int min_bias_align = 1;
if ((std::is_same<T, float>::value && sm >= 80) ||
(!std::is_same<T, float>::value && sm >= 75)) {
min_bias_align = 4; // TensorOp on Sm80+ (float) or Sm75+ (fp16/bf16)
} else if (!std::is_same<T, float>::value && sm >= 70) {
min_bias_align = 2; // TensorOp on Volta (fp16)
}
// NOTE: CUTLASS uses kMinimumAlignment = 4 (elements, not bytes) for the bias
// pointer in its epilogue. total_sequence_length is the bias row stride in elements,
// so we check alignment in element count. The contrib_ops convention (4 * sizeof(T))
// conflates bytes with elements; we use the correct value of 4 elements here.
// Note: on SM50/53 (Maxwell), CUTLASS kMinimumAlignment=1, so this is stricter than
// necessary — cases with odd total_sequence_length that previously used MEA on those
// GPUs will now fall to unfused. This is acceptable for these very old architectures.
constexpr int min_bias_align = 4;
if (parameters.total_sequence_length % min_bias_align != 0) {
mea_eligible = false;
}
Expand Down Expand Up @@ -1215,8 +1198,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// to replicate kv_num_heads -> q_num_heads before unfused can process.
// Requires ~160 lines. See issue #27516.
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"GQA (q_num_heads != kv_num_heads) requires flash or memory efficient attention, "
"but neither is eligible. Ensure fp16/bf16 on Ampere+ GPU, or check head_size constraints.");
"ONNX Attention with GQA (q_num_heads != kv_num_heads) is not supported by the "
"unfused runner. Flash requires fp16/bf16, SM>=80, and attn_mask==nullptr; MEA "
"requires past_key==nullptr. See PR #27851 for MEA past_key support.");
}
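For illustration of what that missing fallback would involve (hypothetical sketch, not part of this PR; see issue #27516), replicating KV heads so each query head reads the KV head of its group is conceptually:

// Hypothetical GQA -> MHA expansion in BNSH layout (assumes q_heads % kv_heads == 0):
// query head h shares kv head h / (q_heads / kv_heads).
template <typename T>
__global__ void ReplicateKvHeadsSketch(const T* kv_in, T* kv_out,
                                       int batch, int q_heads, int kv_heads,
                                       int seq, int head_size) {
  int64_t total = static_cast<int64_t>(batch) * q_heads * seq * head_size;
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (idx >= total) return;
  int64_t d = idx % head_size;
  int64_t s = (idx / head_size) % seq;
  int64_t h = (idx / (static_cast<int64_t>(head_size) * seq)) % q_heads;
  int64_t b = idx / (static_cast<int64_t>(head_size) * seq * q_heads);
  int64_t kv_h = h / (q_heads / kv_heads);
  kv_out[idx] = kv_in[((b * kv_heads + kv_h) * seq + s) * head_size + d];
}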

return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value,
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/llm/attention.h
@@ -18,7 +18,7 @@ class Attention final : public CudaKernel {
Status RunFlashAttention(
OpKernelContext* context,
const Tensor* Q, const Tensor* K, const Tensor* V,
const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
const Tensor* past_key, const Tensor* past_value,
const Tensor* nonpad_kv_seqlen,
Tensor* Y, Tensor* present_key, Tensor* present_value,
const attention_helper::AttentionParameters& parameters) const;