onnxruntime/core/providers/cuda/llm/attention.cc (140 additions, 37 deletions)
@@ -546,12 +546,12 @@ Status Attention<T>::RunFlashAttention(
// ============================================================================
//
// Memory Efficient Attention (cutlass FMHA) dispatch paths:
-// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
-// Path 2: no past, with mask (prompt) -> standard MEA with additive bias
-// Path 3: no past, no mask (prompt) -> standard MEA
+// Path 1: Decode with past KV cache -> LaunchConcatNewToPastKV then standard MEA
+// Path 2: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
+// Path 3: Prompt with mask -> standard MEA with additive bias
+// Path 4: Prompt without mask -> standard MEA
// Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype,
-// head_size <= 1024), plus: no output_qk, no past_key (decode excluded),
-// bias stride alignment.
+// head_size <= 1024), plus: no output_qk, bias stride alignment.
// Note: softcap is forwarded to the MEA kernel via p.softcap. softmax_precision
// is inherently satisfied (cutlass FMHA accumulates softmax in FP32).
//
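
For orientation, the four dispatch paths above form a simple cascade over the inputs. A minimal sketch of that ordering (the enum and helper are illustrative names, not ORT APIs):

enum class MeaPath { DecodeWithPast, CustomRightPadding, PromptWithMask, PromptNoMask };

// Illustrative only: mirrors the dispatch order documented in the comment above.
MeaPath SelectMeaPath(bool has_past_kv, bool has_nonpad_kv_seqlen, bool has_mask) {
  if (has_past_kv) return MeaPath::DecodeWithPast;               // Path 1: concat past+new, then standard MEA
  if (has_nonpad_kv_seqlen) return MeaPath::CustomRightPadding;  // Path 2: opset 24 external cache
  if (has_mask) return MeaPath::PromptWithMask;                  // Path 3: additive bias
  return MeaPath::PromptNoMask;                                  // Path 4
}
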
@@ -564,8 +564,6 @@ Status Attention<T>::RunMemoryEfficientAttention(
Tensor* Y, Tensor* present_key, Tensor* present_value,
const attention_helper::AttentionParameters& parameters) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
-ORT_UNUSED_PARAMETER(past_key);
-ORT_UNUSED_PARAMETER(past_value);
auto& device_prop = GetDeviceProp();
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
const bool is_bsnh = parameters.transpose_output;
@@ -600,6 +598,106 @@ Status Attention<T>::RunMemoryEfficientAttention(
out_data = out_bsnh_buffer.get();
}

+bool present_kv_already_populated = false;
+// Track the effective layout of k_data/v_data. Initially matches input layout,
+// but changes to BNSH (false) after decode concat into present buffers.
+bool kv_is_bsnh = is_bsnh;
+
+// --- Decode path: concat past + new K/V → present buffers (BNSH) ---
+// nonpad_kv_seqlen and past_key are mutually exclusive (enforced at validation),
+// so the decode path only needs the internal-cache (past_key/present_key) flow.
+if (past_key != nullptr) {
+ORT_ENFORCE(past_value != nullptr, "past_key requires past_value.");
+ORT_ENFORCE(present_key != nullptr && present_value != nullptr,
+"present_key/value outputs are required when past_key is provided.");
+ORT_ENFORCE(parameters.head_size == parameters.v_head_size,
+"MEA decode (past_key) requires head_size == v_head_size for LaunchConcatNewToPastKV.");
+
+using NativeCudaT = typename OrtToCudaType<T>::type;
+
+// Step 1: Compute per-batch past sequence lengths for the concat kernel.
+auto past_seqlens_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
+if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
+size_t mask_dims = attn_mask->Shape().NumDimensions();
+auto dims = attn_mask->Shape().GetDims();
+int64_t mask_dim0 = dims[0];
+int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0;
+int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0;
+// Offset -kv_seq: mask encodes total valid count; subtract to get past-only count.
+int seqlen_offset = -parameters.kv_sequence_length;
+ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK(
+attn_mask->Data<bool>(), past_seqlens_buffer.get(),
+parameters.batch_size, parameters.total_sequence_length,
+static_cast<int>(mask_dims), mask_dim0, mask_dim1, mask_dim2,
+cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset));
+} else {
+ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length,
+parameters.batch_size, cuda_stream,
+device_prop.maxThreadsPerBlock));
+}
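
Step 1 reduces a boolean attention mask to one past-length integer per batch. A minimal CUDA sketch of the 2D (batch_size, total_seq_len) case; the actual LaunchConvertMaskToFlashSeqlensK also handles the 3D/4D mask shapes passed above:

// Hypothetical simplified kernel, one block per batch row. Counts valid mask
// positions, then applies seqlen_offset (-kv_sequence_length here) so the
// result is the past-only length rather than the total valid length.
__global__ void MaskToPastSeqlensSketch(const bool* mask, int* past_seqlens,
                                        int total_seq_len, int seqlen_offset) {
  __shared__ int total;
  if (threadIdx.x == 0) total = 0;
  __syncthreads();
  int b = blockIdx.x;
  int count = 0;
  for (int s = threadIdx.x; s < total_seq_len; s += blockDim.x) {
    if (mask[(long long)b * total_seq_len + s]) ++count;
  }
  atomicAdd(&total, count);
  __syncthreads();
  if (threadIdx.x == 0) past_seqlens[b] = total + seqlen_offset;
}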

+// Step 2: Transpose K/V to BSNH if 4D BNSH (concat kernel reads new tokens as BSNH).
+const T* k_new_bsnh = K->Data<T>();
+const T* v_new_bsnh = V->Data<T>();
+IAllocatorUniquePtr<void> k_bsnh_buffer;
+IAllocatorUniquePtr<void> v_bsnh_buffer;
+if (!is_bsnh) {
+size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+parameters.kv_num_heads * parameters.head_size;
+size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+parameters.kv_num_heads * parameters.v_head_size;
+k_bsnh_buffer = GetScratchBuffer<void>(k_bytes, context->GetComputeStream());
+v_bsnh_buffer = GetScratchBuffer<void>(v_bytes, context->GetComputeStream());
+ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH<T>(
+parameters.batch_size, parameters.kv_sequence_length,
+parameters.kv_num_heads, parameters.head_size,
+K->Data<T>(), k_bsnh_buffer.get(),
+cuda_stream, device_prop.maxThreadsPerBlock));
+ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH<T>(
+parameters.batch_size, parameters.kv_sequence_length,
+parameters.kv_num_heads, parameters.v_head_size,
+V->Data<T>(), v_bsnh_buffer.get(),
+cuda_stream, device_prop.maxThreadsPerBlock));
+k_new_bsnh = static_cast<const T*>(k_bsnh_buffer.get());
+v_new_bsnh = static_cast<const T*>(v_bsnh_buffer.get());
+}
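
The transpose in Step 2 is a pure axis permutation: element (b, n, s, h) in BNSH moves to (b, s, n, h) in BSNH. An illustrative elementwise kernel, not the actual TransposeBNSHtoBSNH implementation:

// One thread per element: decompose the BNSH linear index, emit the BSNH index.
template <typename T>
__global__ void TransposeBNSHtoBSNHSketch(const T* in, T* out,
                                          int B, int N, int S, int H) {
  long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= (long long)B * N * S * H) return;
  int h = (int)(idx % H);
  long long t = idx / H;
  int s = (int)(t % S);
  t /= S;
  int n = (int)(t % N);
  int b = (int)(t / N);
  // Source index idx equals ((b*N + n)*S + s)*H + h by construction (BNSH).
  out[(((long long)b * S + s) * N + n) * H + h] = in[idx];
}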

+// Step 3: Fused concat: past_key + new_key → present_key (BNSH).
+// When bool masks produce variable per-batch past_seq_lens, positions in the range
+// [past_seq_lens[b] + kv_sequence_length, total_sequence_length) are not written by
+// the concat kernel. Zero the buffers first to prevent NaN propagation: MEA reads
+// all positions (masked by additive bias), unlike Flash which bounds reads via seqlens_k.
+CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_key->MutableData<T>(), 0,
+present_key->SizeInBytes(), cuda_stream));
+CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_value->MutableData<T>(), 0,
+present_value->SizeInBytes(), cuda_stream));
+ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV<NativeCudaT>(
+parameters.batch_size,
+parameters.kv_num_heads,
+parameters.head_size,
+parameters.kv_sequence_length,
+parameters.past_sequence_length,
+parameters.total_sequence_length,
+/*is_bsnh=*/false,
+past_seqlens_buffer.get(),
+/*total_seq_lens=*/nullptr,
+reinterpret_cast<const NativeCudaT*>(past_key->Data<T>()),
+reinterpret_cast<const NativeCudaT*>(past_value->Data<T>()),
+reinterpret_cast<const NativeCudaT*>(k_new_bsnh),
+reinterpret_cast<const NativeCudaT*>(v_new_bsnh),
+reinterpret_cast<NativeCudaT*>(present_key->MutableData<T>()),
+reinterpret_cast<NativeCudaT*>(present_value->MutableData<T>()),
+cuda_stream,
+device_prop.maxThreadsPerBlock,
+/*past_only=*/false));
+
+// Point MEA's K/V inputs at the concatenated present buffers (BNSH).
+k_data = present_key->Data<T>();
+v_data = present_value->Data<T>();
+kv_is_bsnh = false;
+present_kv_already_populated = true;
+}
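
After the zero-fill and concat, each (batch, head) slice of the present cache has three regions along the sequence axis. A small host-side sketch of that layout (hypothetical helper, for exposition only):

// Where present-cache position s comes from for batch b, given the per-batch
// past length from Step 1. The zeroed tail is what makes it safe for MEA to
// read all total_sequence_length positions.
enum class PresentSource { Past, New, ZeroPad };

PresentSource ClassifyPresentPos(int s, int past_len_b, int kv_seq_len) {
  if (s < past_len_b) return PresentSource::Past;              // copied from past_key/past_value
  if (s < past_len_b + kv_seq_len) return PresentSource::New;  // copied from new K/V
  return PresentSource::ZeroPad;                               // left zero by cudaMemsetAsync
}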

// GQA head expansion: MEA requires matching num_heads for Q/K/V.
// When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup.
const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads;
@@ -640,7 +738,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
reinterpret_cast<const float2*>(v_data),
parameters.total_sequence_length,
parameters.total_sequence_length,
-is_bsnh,
+kv_is_bsnh,
cuda_stream,
device_prop.maxThreadsPerBlock));

@@ -649,8 +747,8 @@ Status Attention<T>::RunMemoryEfficientAttention(
}
}
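
The LaunchUngroup call above repeats each KV head across its query-head group so MEA sees matching head counts. The underlying head mapping is plain integer division; a minimal sketch:

// Illustrative GQA head mapping: query head q reads KV head q / group, where
// group = q_num_heads / kv_num_heads (assumed to divide evenly).
int KvHeadForQueryHead(int q, int q_num_heads, int kv_num_heads) {
  int group = q_num_heads / kv_num_heads;
  return q / group;
}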

-// Note: MEA with past_key/value is handled by the unfused fallback.
-// The cascade in ComputeInternal ensures past_key == nullptr when we reach here.
+// Note: When past_key is present (decode), k_data/v_data already point to present
+// buffers (BNSH) after LaunchConcatNewToPastKV above, so MEA sees the full cache.

// Handle attention mask → attention_bias conversion
IAllocatorUniquePtr<void> converted_mask_buffer;
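
The conversion staged in converted_mask_buffer follows the standard additive-bias trick: allowed positions contribute 0 to the attention score and masked positions a large negative value, so softmax drives them to ~0. A sketch under that assumption, not the exact ORT conversion kernel (which also handles mask-shape broadcasting):

// Illustrative bool-mask -> additive-bias conversion, one thread per element.
template <typename T>
__global__ void MaskToAdditiveBiasSketch(const bool* mask, T* bias, long long n) {
  long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) bias[i] = mask[i] ? T(0.0f) : T(-10000.0f);
}
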
@@ -683,7 +781,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
p.sm = sm;
p.is_half = std::is_same<T, MLFloat16>::value;
p.is_bf16 = std::is_same<T, BFloat16>::value;
-p.is_kv_bsnh = is_bsnh;
+p.is_kv_bsnh = kv_is_bsnh;
p.batch_size = parameters.batch_size;
p.num_heads = parameters.q_num_heads;
p.sequence_length = parameters.q_sequence_length;
@@ -733,7 +831,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
p.sm = sm;
p.is_half = std::is_same<T, MLFloat16>::value;
p.is_bf16 = std::is_same<T, BFloat16>::value;
-p.is_kv_bsnh = is_bsnh;
+p.is_kv_bsnh = kv_is_bsnh;
p.batch_size = parameters.batch_size;
p.num_heads = parameters.q_num_heads;
p.sequence_length = parameters.q_sequence_length;
@@ -775,30 +873,33 @@ Status Attention<T>::RunMemoryEfficientAttention(
cuda_stream, device_prop.maxThreadsPerBlock));
}

-// Populate present_key/present_value (BNSH) if requested
-if (present_key != nullptr && is_bsnh) {
-ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH<T>(
-parameters.batch_size, parameters.kv_sequence_length,
-parameters.kv_num_heads, parameters.head_size,
-K->Data<T>(), present_key->MutableData<T>(),
-cuda_stream, device_prop.maxThreadsPerBlock));
-} else if (present_key != nullptr && !is_bsnh) {
-// 4D BNSH prompt: K is already BNSH, just D2D copy to present
-CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
-present_key->MutableData<T>(), K->Data<T>(),
-K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
-}
-if (present_value != nullptr && is_bsnh) {
-ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH<T>(
-parameters.batch_size, parameters.kv_sequence_length,
-parameters.kv_num_heads, parameters.v_head_size,
-V->Data<T>(), present_value->MutableData<T>(),
-cuda_stream, device_prop.maxThreadsPerBlock));
-} else if (present_value != nullptr && !is_bsnh) {
-// 4D BNSH prompt: V is already BNSH, just D2D copy to present
-CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
-present_value->MutableData<T>(), V->Data<T>(),
-V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+// Populate present_key/present_value (BNSH) if requested.
+// Skip for decode path where LaunchConcatNewToPastKV already populated present buffers.
+if (!present_kv_already_populated) {
+if (present_key != nullptr && is_bsnh) {
+ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH<T>(
+parameters.batch_size, parameters.kv_sequence_length,
+parameters.kv_num_heads, parameters.head_size,
+K->Data<T>(), present_key->MutableData<T>(),
+cuda_stream, device_prop.maxThreadsPerBlock));
+} else if (present_key != nullptr && !is_bsnh) {
+// 4D BNSH prompt: K is already BNSH, just D2D copy to present
+CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+present_key->MutableData<T>(), K->Data<T>(),
+K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+}
+if (present_value != nullptr && is_bsnh) {
+ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH<T>(
+parameters.batch_size, parameters.kv_sequence_length,
+parameters.kv_num_heads, parameters.v_head_size,
+V->Data<T>(), present_value->MutableData<T>(),
+cuda_stream, device_prop.maxThreadsPerBlock));
+} else if (present_value != nullptr && !is_bsnh) {
+// 4D BNSH prompt: V is already BNSH, just D2D copy to present
+CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+present_value->MutableData<T>(), V->Data<T>(),
+V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+}
}

return Status::OK();
@@ -1148,7 +1249,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value,
parameters.head_size, parameters.v_head_size) &&
!has_output_qk &&
-past_key == nullptr;
+// MEA decode requires head_size == v_head_size for LaunchConcatNewToPastKV
+// (single head_size parameter). Fall back to unfused when they differ.
+!(past_key != nullptr && parameters.head_size != parameters.v_head_size);

// Cutlass FMHA requires bias strides to satisfy minimum alignment even in the
// "unaligned" kernel path. When an attention mask is present (with or without