diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
index 27b81691c70af..0c2a646278e65 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -134,7 +134,7 @@ static Status TransposeBSNHtoBNSH(int batch_size, int sequence_length,
 // ============================================================================
 // ConvertAttnMaskToBias: shared helper for mask→additive bias conversion.
-// Used by both Flash (nonpad+mask) and MEA paths to avoid code duplication.
+// Used by the MEA path to convert masks before the CUTLASS kernel call.
 // Converts bool masks to additive bias (true→0, false→mask_filter_value),
 // passes float masks through directly, and sets broadcast flags from mask shape.
 // ============================================================================
@@ -186,15 +186,12 @@ Status Attention<T>::ConvertAttnMaskToBias(
 // Flash Attention dispatch paths:
 // Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache
 // Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache
-//   - Supports: bool mask -> seqlens_k, no mask -> fill past_seq_len
+//   - No mask support (attn_mask rejected at eligibility)
 //   - 4D BNSH: transposes Q/K/V to BSNH before kernel
 // Path 3: no past, no mask (prompt) -> mha_fwd
-// Eligibility: fp16/bf16, head_size==v_head_size, no output_qk,
-//   (no mask OR bool mask + past OR nonpad_kv_seqlen without mask)
+// Eligibility: fp16/bf16, head_size==v_head_size, no output_qk, attn_mask==nullptr
 // Note: softcap is passed to the Flash kernel natively. softmax_precision is
 // inherently satisfied (Flash accumulates softmax in FP32).
-// Note: nonpad_kv_seqlen + attn_mask routes to MEA/unfused, not Flash
-// (Flash has no bias parameter for this combination).
 //
 // PERFORMANCE NOTE: ONNX Attention's internal-cache decode path (past_key/past_value)
 // is ~15-30% slower than contrib GQA's decode path for grouped-query attention workloads.
@@ -227,7 +224,7 @@ template <typename T>
 Status Attention<T>::RunFlashAttention(
     OpKernelContext* context,
     const Tensor* Q, const Tensor* K, const Tensor* V,
-    const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
+    const Tensor* past_key, const Tensor* past_value,
     const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value,
     const attention_helper::AttentionParameters& parameters) const {
@@ -294,6 +291,8 @@ Status Attention<T>::RunFlashAttention(
                       "(past_sequence_length must be 0, got ",
                       parameters.past_sequence_length, ").");
 
+    // seqlens_k_buffer lifetime: allocated via BFC arena, remains valid for all kernel
+    // launches on the same CUDA stream until the IAllocatorUniquePtr goes out of scope.
    auto seqlens_k_buffer = GetScratchBuffer<int>(parameters.batch_size, GetComputeStream(context));
    ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK(
        nonpad_kv_seqlen->Data<int32_t>(),
@@ -348,25 +347,11 @@ Status Attention<T>::RunFlashAttention(
     // Step 1: Compute per-batch past sequence lengths for the concat kernel.
     // The concat kernel needs past_seq_lens to know where past data ends and new begins.
+    // attn_mask is always nullptr here (Flash rejects attn_mask), so use uniform past_seq.
auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); - if (attn_mask != nullptr && attn_mask->IsDataType()) { - size_t mask_dims = attn_mask->Shape().NumDimensions(); - auto dims = attn_mask->Shape().GetDims(); - int64_t mask_dim0 = dims[0]; - int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0; - int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0; - // Offset -kv_seq: mask encodes total valid count; subtract to get past-only count. - int seqlen_offset = -parameters.kv_sequence_length; - ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK( - attn_mask->Data(), past_seqlens_buffer.get(), - parameters.batch_size, parameters.total_sequence_length, - static_cast(mask_dims), mask_dim0, mask_dim1, mask_dim2, - cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset)); - } else { - ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, - parameters.batch_size, cuda_stream, - device_prop.maxThreadsPerBlock)); - } + ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, + parameters.batch_size, cuda_stream, + device_prop.maxThreadsPerBlock)); // Step 2: Transpose K/V to BSNH if input is 4D BNSH (concat kernel reads new as BSNH). const T* k_new_bsnh = K->Data(); @@ -399,11 +384,7 @@ Status Attention::RunFlashAttention( // into present buffer at [past_seq, past_seq + kv_seq), all in BNSH. // Note: is_bsnh=false means past/present cache layout is BNSH. New tokens // (k_new_bsnh/v_new_bsnh) are always read as BSNH by the kernel (hardcoded strides). - // NOTE: When bool masks produce variable per-batch past_seq_lens, positions in the range - // [past_seq_lens[b] + kv_sequence_length, total_sequence_length) for each batch b are left - // uninitialized by the concat kernel. Flash Attention reads only up to seqlens_k[b] positions - // per batch, so these values are never accessed. In the no-mask case (uniform past_seq_lens), - // every position in the present buffer is written. + // past_seqlens is uniform (no mask) so every position in the present buffer is written. ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( parameters.batch_size, parameters.kv_num_heads, @@ -427,27 +408,13 @@ Status Attention::RunFlashAttention( // Step 4: Compute total seqlens for mha_fwd_kvcache. // With k_new=nullptr, the kernel treats seqlens_k as the total valid token count // (not pre-append count), so we need past + new. + // attn_mask is always nullptr here (Flash rejects attn_mask), so use uniform seqlens. auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); - if (attn_mask != nullptr && attn_mask->IsDataType()) { - size_t mask_dims = attn_mask->Shape().NumDimensions(); - auto dims = attn_mask->Shape().GetDims(); - int64_t mask_dim0 = dims[0]; - int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0; - int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0; - // Offset 0: mask encodes total valid count, which is exactly what we need. 
-        int seqlen_offset = 0;
-        ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK(
-            attn_mask->Data<bool>(), seqlens_k_buffer.get(),
-            parameters.batch_size, parameters.total_sequence_length,
-            static_cast<int>(mask_dims), mask_dim0, mask_dim1, mask_dim2,
-            cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset));
-      } else {
-        ORT_RETURN_IF_ERROR(LaunchFillInt32(
-            seqlens_k_buffer.get(),
-            parameters.past_sequence_length + parameters.kv_sequence_length,
-            parameters.batch_size, cuda_stream,
-            device_prop.maxThreadsPerBlock));
-      }
+      ORT_RETURN_IF_ERROR(LaunchFillInt32(
+          seqlens_k_buffer.get(),
+          parameters.past_sequence_length + parameters.kv_sequence_length,
+          parameters.batch_size, cuda_stream,
+          device_prop.maxThreadsPerBlock));
 
     // Step 5: Flash attention on pre-populated cache.
     // k_new=nullptr tells mha_fwd_kvcache to skip its internal Append_KV — the cache
@@ -542,7 +509,6 @@ Status Attention<T>::RunFlashAttention(
   ORT_UNUSED_PARAMETER(Q);
   ORT_UNUSED_PARAMETER(K);
   ORT_UNUSED_PARAMETER(V);
-  ORT_UNUSED_PARAMETER(attn_mask);
   ORT_UNUSED_PARAMETER(past_key);
   ORT_UNUSED_PARAMETER(past_value);
   ORT_UNUSED_PARAMETER(nonpad_kv_seqlen);
@@ -730,6 +696,22 @@ Status Attention<T>::RunMemoryEfficientAttention(
       p.workspace = nullptr;
     }
     onnxruntime::contrib::cuda::run_memory_efficient_attention(p);
+
+    // On the MEA (CUTLASS) path (used for both MHA and GQA when nonpad_kv_seqlen is provided),
+    // zero out output for fully-masked batches to produce zeros (matching Flash behavior).
+    // CUTLASS epilogue computes 1/s_prime where s_prime=0 for seqlens_k=0, producing NaN.
+    {
+      using CudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
+      int64_t elements_per_batch = static_cast<int64_t>(parameters.q_sequence_length) *
+                                   parameters.q_num_heads * parameters.v_head_size;
+      ORT_RETURN_IF_ERROR(LaunchZeroOutputForFullyMaskedBatches(
+          reinterpret_cast<CudaT*>(out_data),
+          seqlens_k_buffer.get(),
+          parameters.batch_size,
+          elements_per_batch,
+          cuda_stream,
+          device_prop.maxThreadsPerBlock));
+    }
   }
   // Standard MEA path: float attention bias, bool mask (converted to bias), or no mask.
   // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value)
@@ -858,6 +840,8 @@ Status Attention<T>::RunUnfusedAttention(
     Tensor* output_qk,
     const attention_helper::AttentionParameters& parameters) const {
   using CudaT = typename ToCudaType<T>::MappedType;
+  // OrtToCudaType maps BFloat16 → __nv_bfloat16 (native HW type), matching kernel instantiations.
+  using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
   auto& device_prop = GetDeviceProp();
   auto cuda_stream = Stream(context);
   auto ort_stream = GetOrtStream(context);
@@ -938,7 +922,6 @@ Status Attention<T>::RunUnfusedAttention(
   IAllocatorUniquePtr<void> mask_bias_buffer;  // temp buffer for mask→bias when composing
   if (nonpad_kv_seqlen != nullptr) {
     // Convert nonpad_kv_seqlen to additive attention bias: [B, q_seq, total_seq]
-    using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
     int64_t bias_elements = static_cast<int64_t>(parameters.batch_size) *
                             parameters.q_sequence_length *
                             parameters.total_sequence_length;
@@ -1004,7 +987,6 @@ Status Attention<T>::RunUnfusedAttention(
     contribop_parameters.broadcast_attn_bias_dim_1 = true;
   } else if (attn_mask != nullptr) {
     if (attn_mask->IsDataType<bool>()) {
-      using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
       int64_t num_elements = attn_mask->Shape().Size();
       converted_mask_buffer = GetScratchBuffer<void>(num_elements * sizeof(NativeCudaT), GetComputeStream(context));
       ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias(
@@ -1049,6 +1031,9 @@ Status Attention<T>::RunUnfusedAttention(
   cublasHandle_t cublas = GetCublasHandle(context);
   cudnnHandle_t cudnn = GetCudnnHandle(context);
 
+  // Note: unfused attention produces valid finite output (mean-of-V via uniform softmax)
+  // for fully-masked batches, so ZeroOutput is not needed here. Only MEA requires
+  // ZeroOutput to prevent NaN from the CUTLASS epilogue's 1/s_prime division.
   return onnxruntime::contrib::cuda::QkvToContext<CudaT>(
       device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data);
 }
@@ -1134,20 +1119,17 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
                              parameters.q_num_heads, parameters.kv_num_heads) &&
       parameters.head_size == parameters.v_head_size &&
       !has_output_qk &&
-      // Bool masks without past_key (prompt) can't use flash because mha_fwd_kvcache's
-      // causal semantics are decode-oriented (window offset by seqlens_k). For causal
-      // prompt with padding, MEA handles it correctly via attention bias conversion.
-      // Flash handles: no mask, decode with past (±mask), nonpad_kv_seqlen.
-      // Note: contrib MHA similarly excludes flash when attention_bias is present
-      // (no mask support in mha_fwd). Float masks and bool prompt masks route to MEA
-      // which supports additive bias natively.
-      (attn_mask == nullptr || (attn_mask->IsDataType<bool>() && past_key != nullptr)) &&
-      // Flash cannot handle nonpad_kv_seqlen + attn_mask simultaneously (no bias parameter
-      // in mha_fwd/mha_fwd_kvcache when seqlens_k is used). Route to MEA instead.
-      !(nonpad_kv_seqlen != nullptr && attn_mask != nullptr);
+      // Flash does not support attention masks (no bias parameter in mha_fwd/mha_fwd_kvcache).
+      // Bool attn_mask + past_key is rejected because Flash uses paged KV cache semantics
+      // that produce spec-divergent present_kv layout for partial masks (e.g. [T,T,T,F]).
+      // Unfused handles bool+past_key spec-correctly via standard ConcatPastToPresent.
+      // TODO(titaiwang): GQA + bool attn_mask + past_key currently has no runner (Flash
+      // rejected here, unfused doesn't support GQA, MEA blocked by past_key != nullptr).
+      // Once PR #27851 merges (MEA supports past_key), this gap will be covered.
+      attn_mask == nullptr;
 
   if (flash_eligible) {
-    return RunFlashAttention(context, Q, K, V, attn_mask, past_key, past_value,
+    return RunFlashAttention(context, Q, K, V, past_key, past_value,
                              nonpad_kv_seqlen, Y, present_key, present_value, parameters);
   }
 }
@@ -1171,13 +1153,14 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   // total_sequence_length. Skip MEA if this stride can't satisfy the kernel's
   // minimum alignment requirement.
   if (mea_eligible && attn_mask != nullptr) {
-    int min_bias_align = 1;
-    if ((std::is_same<T, float>::value && sm >= 80) ||
-        (!std::is_same<T, float>::value && sm >= 75)) {
-      min_bias_align = 4;  // TensorOp on Sm80+ (float) or Sm75+ (fp16/bf16)
-    } else if (!std::is_same<T, float>::value && sm >= 70) {
-      min_bias_align = 2;  // TensorOp on Volta (fp16)
-    }
+    // NOTE: CUTLASS uses kMinimumAlignment = 4 (elements, not bytes) for the bias
+    // pointer in its epilogue. total_sequence_length is the bias row stride in elements,
+    // so we check alignment in element count. The contrib_ops convention (4 * sizeof(T))
+    // conflates bytes with elements; we use the correct value of 4 elements here.
+    // Note: on SM50/53 (Maxwell), CUTLASS kMinimumAlignment=1, so this is stricter than
+    // necessary — cases with odd total_sequence_length that previously used MEA on those
+    // GPUs will now fall to unfused. This is acceptable for these very old architectures.
+    constexpr int min_bias_align = 4;
     if (parameters.total_sequence_length % min_bias_align != 0) {
       mea_eligible = false;
     }
@@ -1215,8 +1198,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     // to replicate kv_num_heads -> q_num_heads before unfused can process.
     // Requires ~160 lines. See issue #27516.
     return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                           "GQA (q_num_heads != kv_num_heads) requires flash or memory efficient attention, "
-                           "but neither is eligible. Ensure fp16/bf16 on Ampere+ GPU, or check head_size constraints.");
+                           "ONNX Attention with GQA (q_num_heads != kv_num_heads) is not supported by the "
+                           "unfused runner. Flash requires fp16/bf16, SM>=80, and attn_mask==nullptr; MEA "
+                           "requires past_key==nullptr. See PR #27851 for MEA past_key support.");
   }
 
   return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value,
diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h
index 3f69c7a77f497..c53c5c80d61e2 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.h
+++ b/onnxruntime/core/providers/cuda/llm/attention.h
@@ -18,7 +18,7 @@ class Attention final : public CudaKernel {
   Status RunFlashAttention(
       OpKernelContext* context,
       const Tensor* Q, const Tensor* K, const Tensor* V,
-      const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value,
+      const Tensor* past_key, const Tensor* past_value,
       const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value,
       const attention_helper::AttentionParameters& parameters) const;
diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
index e3ede1a0bbc9e..4ab3990b2f85d 100644
--- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
+++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
@@ -7,144 +7,6 @@
 namespace onnxruntime {
 namespace cuda {
 
-// CUDA kernel to convert boolean attention mask to sequence lengths.
-// Also validates that the mask follows right-padding convention.
-//
-// The kernel processes one batch per thread.
-// For each batch, it finds the first False in the mask row, which indicates -// where padding starts. The sequence length is the index of first False. -// -// Validation: -// - All-false masks are valid (represents fully masked / zero-length sequence) -// - After the first False, all remaining elements must be False (contiguous padding) -// - CUDA_KERNEL_ASSERT fires in debug builds if mask is non-contiguous -// - In release builds, non-contiguous masks produce safe output: seqlens_k is the -// count of leading True values (up to first False), ignoring later True values -// -// Handle broadcasting: -// - 2D mask (q_seq_len, total_seq_len): broadcasts over batch; uses first query position (row 0) -// - 3D mask (num_heads, q_seq_len, total_seq_len): broadcasts to [1, num_heads, q_seq, total_seq] -// No per-batch variation; uses first head, first q position for all batches -// - 4D mask (B, H, q_seq_len, total_seq_len): we look at first head, first q position -__global__ void ConvertMaskToSeqlensKernel( - const bool* __restrict__ attn_mask, - int* __restrict__ seqlens_k, - const int batch_size, - const int total_seq_len, - const int mask_dims, - const int64_t mask_dim0, - const int64_t mask_dim1, - const int64_t mask_dim2, - const int seqlen_offset) { - int batch_idx = threadIdx.x + blockIdx.x * blockDim.x; - if (batch_idx >= batch_size) { - return; - } - - // Calculate the starting offset for this batch's mask row - // We need to figure out which row of the mask to use based on broadcasting rules - const bool* mask_row = nullptr; - - if (mask_dims == 2) { - // Shape: (q_seq_len, total_seq_len) per ONNX spec. Broadcasts over batch. - // Use first query position (row 0) for sequence length determination. - // For 2D masks [q_seq, total_seq], only used in decode path where q_seq=1, - // so row 0 is always correct. Flash excludes 2D bool masks for prompt. - mask_row = attn_mask; - } else if (mask_dims == 3) { - // Shape: (num_heads, q_seq_len, total_seq_len) - // This broadcasts to [1, num_heads, q_seq, total_seq] - same mask for all batches - // We look at first head (h_idx = 0) and first q position (q_idx = 0) - int h_idx = 0; // First head - int q_idx = 0; // First query position - // Stride: q_seq_len * total_seq_len per head - int64_t head_stride = mask_dim1 * total_seq_len; // mask_dim1 = q_seq_len - int64_t q_stride = total_seq_len; - // Same mask row for all batches since 3D has no batch dimension - mask_row = attn_mask + h_idx * head_stride + q_idx * q_stride; - } else { - // 4D: Shape (B, H, q_seq_len, total_seq_len) - // B could be batch_size or 1 (broadcast) - // H could be num_heads or 1 (broadcast) - // We look at first head (h_idx = 0) and first q position (q_idx = 0) - int effective_batch = (mask_dim0 == 1) ? 
0 : batch_idx; - int h_idx = 0; // First head - int q_idx = 0; // First query position - // Strides - int64_t batch_stride = mask_dim1 * mask_dim2 * total_seq_len; - int64_t head_stride = mask_dim2 * total_seq_len; - int64_t q_stride = total_seq_len; - mask_row = attn_mask + effective_batch * batch_stride + h_idx * head_stride + q_idx * q_stride; - } - - // Find the first False (where padding starts) - // All elements before first False must be True, all after must be False (right-padding convention) - int seq_len; - if (!mask_row[0]) { - // Entire row is padding (all-false mask) - seq_len = 0; - } else { - seq_len = total_seq_len; // Default: all True (no padding) - bool found_first_false = false; - - for (int i = 1; i < total_seq_len; ++i) { - bool current = mask_row[i]; - - if (!found_first_false && !current) { - // Found first False - this is where padding starts - seq_len = i; - found_first_false = true; - } else if (found_first_false && current) { - // Found True after False - mask is not contiguous (invalid) - CUDA_KERNEL_ASSERT(false); // mask must be contiguous (no True after False) - break; // Safe: seq_len already reflects leading-True count - } - } - } - - // seqlens_k output: seq_len + seqlen_offset - // Decode with past (seqlen_offset=-kv_seq_len): pre-append cache count - // Prompt/MEA (seqlen_offset=0): actual token count - // Clamp to 0: all-false mask (seq_len=0) with negative decode offset - // would produce negative seqlens_k, which is undefined in Flash kernels. - seqlens_k[batch_idx] = max(0, seq_len + seqlen_offset); -} - -// Convert boolean mask to sequence lengths with a configurable offset. -// seqlens_k[b] = num_true_tokens + seqlen_offset -Status LaunchConvertMaskToFlashSeqlensK( - const bool* attn_mask_bool, - int* seqlens_k, - int batch_size, - int total_seq_len, - int mask_dims, - int64_t mask_dim0, - int64_t mask_dim1, - int64_t mask_dim2, - cudaStream_t stream, - int max_threads_per_block, - int seqlen_offset) { - if (batch_size == 0 || total_seq_len == 0) { - return Status::OK(); - } - - int threads = std::min(batch_size, max_threads_per_block); - int blocks = (batch_size + threads - 1) / threads; - - ConvertMaskToSeqlensKernel<<>>( - attn_mask_bool, - seqlens_k, - batch_size, - total_seq_len, - mask_dims, - mask_dim0, - mask_dim1, - mask_dim2, - seqlen_offset); - - return CUDA_CALL(cudaGetLastError()); -} - template __global__ void ConvertBoolMaskToAttentionBiasKernel( const bool* __restrict__ attn_mask, @@ -328,6 +190,58 @@ template Status LaunchAddBiasInPlace(float*, const float*, int64_t, int64 template Status LaunchAddBiasInPlace<__half>(__half*, const __half*, int64_t, int64_t, cudaStream_t, int); template Status LaunchAddBiasInPlace<__nv_bfloat16>(__nv_bfloat16*, const __nv_bfloat16*, int64_t, int64_t, cudaStream_t, int); +// Zero output elements for batches where seqlens_k == 0 (fully masked). +// CUTLASS MEA epilogue computes 1/s_prime where s_prime=0 → NaN for fully-masked +// batches. The unfused path produces uniform softmax weights (finite mask_filter_value, +// not -inf) so output is valid but non-zero; we still zero for Flash parity. +// Flash handles this natively with an early-exit for empty sequences. 
+template <typename T>
+__global__ void ZeroOutputForFullyMaskedBatchesKernel(
+    T* __restrict__ output,
+    const int* __restrict__ seqlens_k,
+    const int batch_size,
+    const int64_t elements_per_batch) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  int64_t total = static_cast<int64_t>(batch_size) * elements_per_batch;
+  for (; idx < total; idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
+    int b = static_cast<int>(idx / elements_per_batch);
+    if (seqlens_k[b] == 0) {
+      output[idx] = T(0.0f);
+    }
+  }
+}
+
+template <typename T>
+Status LaunchZeroOutputForFullyMaskedBatches(
+    T* output,
+    const int* seqlens_k,
+    int batch_size,
+    int64_t elements_per_batch,
+    cudaStream_t stream,
+    int max_threads_per_block) {
+  int64_t total = static_cast<int64_t>(batch_size) * elements_per_batch;
+  if (total == 0) {
+    return Status::OK();
+  }
+
+  int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total));
+  int64_t blocks = (total + threads - 1) / threads;
+  constexpr int64_t kMaxGridDimX = 65535;
+  unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
+
+  ZeroOutputForFullyMaskedBatchesKernel<<<grid_size, threads, 0, stream>>>(
+      output, seqlens_k, batch_size, elements_per_batch);
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchZeroOutputForFullyMaskedBatches<float>(
+    float*, const int*, int, int64_t, cudaStream_t, int);
+template Status LaunchZeroOutputForFullyMaskedBatches<__half>(
+    __half*, const int*, int, int64_t, cudaStream_t, int);
+template Status LaunchZeroOutputForFullyMaskedBatches<__nv_bfloat16>(
+    __nv_bfloat16*, const int*, int, int64_t, cudaStream_t, int);
+
 // Simple kernel to fill an int32 buffer with a constant value on device.
 // Used for CUDA-graph-capturable seqlens_k initialization (no host memory).
 __global__ void FillInt32Kernel(int* __restrict__ output, const int value, const int count) {
diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
index d5b497e434127..1ada783e9d64d 100644
--- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
+++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
@@ -8,45 +8,6 @@
 namespace onnxruntime {
 namespace cuda {
 
-// Convert a boolean attention mask to sequence lengths with a configurable offset.
-//
-// The mask is expected to have the following properties:
-// 1. It represents right-padding only (valid tokens first, padding at the end)
-// 2. All-false masks (zero-length sequence) are valid; otherwise mask should start with True
-// 3. True values should be contiguous, followed by contiguous False (padding) values
-// 4.
The mask must be broadcastable to (batch_size, num_heads, q_seq_len, total_seq_len) -// -// For 2D mask (q_seq_len, total_seq_len): broadcasts over batch; uses first query position (row 0) -// For 3D mask (num_heads, q_seq_len, total_seq_len): broadcasts across batches, uses first head/q -// For 4D mask (B, H, q_seq_len, total_seq_len): uses first head, first q position -// -// seqlen_offset adjusts the raw token count: -// seqlens_k[b] = num_true_tokens + seqlen_offset -// -// Common offsets: -// 0: total valid token count (for decode Step 4 where mha_fwd_kvcache reads from -// pre-populated cache with k_new=nullptr, and for MEA custom right padding) -// -N: subtract N from count (for decode with mha_fwd_kvcache where N=kv_sequence_length, -// giving the number of tokens already in cache BEFORE appending new ones) -// -// Note: Mask validity (right-padding convention, contiguous True/False) -// is checked via CUDA_KERNEL_ASSERT inside the kernel (debug builds only). -// In release builds, non-contiguous masks produce memory-safe but semantically incorrect output: -// seqlens_k is computed as the count of leading True values (up to the first False), -// ignoring any True values that appear after the first False. -Status LaunchConvertMaskToFlashSeqlensK( - const bool* attn_mask_bool, - int* seqlens_k, - int batch_size, - int total_seq_len, - int mask_dims, - int64_t mask_dim0, - int64_t mask_dim1, - int64_t mask_dim2, - cudaStream_t stream, - int max_threads_per_block, - int seqlen_offset = 0); - // Convert a boolean attention mask to an additive attention bias for the MHA path. // Maps true -> 0.0 (attend) and false -> mask_filter_value (mask out). // The output has the same shape as the input mask. @@ -98,6 +59,20 @@ Status LaunchAddBiasInPlace( cudaStream_t stream, int max_threads_per_block); +// Zero output elements for batches where seqlens_k == 0 (fully masked). +// Used in the MEA path only: CUTLASS epilogue computes 1/s_prime where s_prime=0, +// producing NaN for fully-masked batches. This kernel overwrites those NaN outputs +// with zeros. The unfused path produces valid finite output (mean-of-V via uniform +// softmax) and does not need this. Flash handles it natively with an early-exit. +template +Status LaunchZeroOutputForFullyMaskedBatches( + T* output, + const int* seqlens_k, + int batch_size, + int64_t elements_per_batch, + cudaStream_t stream, + int max_threads_per_block); + // Fill an int32 buffer with a constant value entirely on device. // CUDA-graph-capturable alternative to host vector + cudaMemcpyAsync. Status LaunchFillInt32(int* output, int value, int count, cudaStream_t stream, int max_threads_per_block); diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index b651a47b582ac..d287ce5da1504 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -514,9 +514,13 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) { ); } -// Regression test: all-false bool mask in decode mode (past_sequence_length > 0). -// Before the fix: seq_len=0 + negative seqlen_offset produced negative seqlens_k → UB/crash. -// After the fix: clamped to max(0, ...) → uniform softmax → mean of V values. +// Regression guard: all-false bool mask in decode mode (past_sequence_length > 0). +// Guards against a bug where fully-masked batches produce NaN or incorrect output. 
+// Expected behavior: uniform softmax over all KV positions (past + new) produces Y = mean-of-V.
+// With V rows [0.1, 0.2, 0.3, 0.4] (head 0) and [0.5, 0.6, 0.7, 0.8] (head 1), and all positions
+// masked out, every logit gets the same finite mask_filter_value → uniform weights → Y = {0.25, 0.65}.
+// This test originally came from upstream/main and validates that both CPU and CUDA
+// (unfused path) handle the all-false mask case identically.
 TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) {
   int batch_size = 1;
   int q_num_heads = 2;
@@ -576,9 +580,11 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) {
       0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f, 0.8f};
 
   // With all-false mask, softmax produces uniform weights: 1/4 per position.
-  // Output = mean of V rows (past_value concat new_v):
+  // Standard concat places new token at past_sequence_length, so present_value =
+  // [past_v[0], past_v[1], past_v[2], new_v]. Output = mean of all V rows:
   //   head 0: mean(0.1, 0.2, 0.3, 0.4) = 0.25
   //   head 1: mean(0.5, 0.6, 0.7, 0.8) = 0.65
+  // These values match upstream/main behavior (unfused standard concat path).
   std::vector<float> y = {
       // head 0
       0.25f, 0.25f, 0.25f, 0.25f, 0.25f, 0.25f, 0.25f, 0.25f,
@@ -603,9 +609,9 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) {
   );
 }
 
-// Flash decode path with fp16 and all-true bool attention mask.
-// Exercises LaunchConcatNewToPastKV + mha_fwd_kvcache(k_new=nullptr) with mask.
-// head_size=64 is Flash-eligible. Uniform keys make output analytically verifiable:
+// Unfused decode path with fp16 and all-true bool attention mask.
+// Flash rejects attn_mask (requires attn_mask==nullptr), so CUDA routes to unfused.
+// head_size=64. Uniform keys make output analytically verifiable:
 // all attention scores are equal, so softmax is uniform over all positions.
 TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) {
   int batch_size = 1;
@@ -687,17 +693,13 @@ TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) {
   );
 }
 
-// Flash decode path with fp16 and partial bool mask (CUDA-only, direct OpTester).
-// mask [T,T,T,F] → Flash counts 3 leading trues, past_seqlens=2, new token at position 2.
-// Flash attends to [past[0], past[1], new_token] → Y = mean of those 3.
-// Uses direct OpTester for per-output tolerance: present_key/value position 3 may be
-// uninitialized (concat kernel only fills past_seqlens+1 positions).
-// CUDA-only because Flash's seqlens_k mask semantics differ from CPU's element-wise mask.
+// Decode with partial bool mask [T,T,T,F]: the new token is masked out.
+// With mask [T,T,T,F], past_seq=3, total=4: only positions 0,1,2 are attended (past only).
+// Flash is ineligible (bool+past_key rejected), so CUDA uses unfused, which handles this
+// spec-correctly via standard ConcatPastToPresent + element-wise mask application.
+// Y = uniform mean over the 3 attended past values (Q=K=constant → uniform softmax).
+// CPU always runs; CUDA runs when SM 5.3+ is available.
TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { - if (!HasCudaEnvironment(530)) { - return; // fp16 requires SM 5.3+ - } - int batch_size = 1; int q_num_heads = 2; int q_sequence_length = 1; @@ -728,21 +730,26 @@ TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { v_head_size, pv[h][s]); } - // Y: uniform 1/3 over [past_v[0], past_v[1], v_new] - // head 0: (0.1 + 0.2 + 0.4) / 3 = 7/30 - // head 1: (0.5 + 0.6 + 0.8) / 3 = 19/30 + // Y: uniform 1/3 over past values [past_v[0], past_v[1], past_v[2]] (new token masked out). + // head 0: (0.1 + 0.2 + 0.3) / 3 = 0.2 + // head 1: (0.5 + 0.6 + 0.7) / 3 = 0.6 std::vector y(batch_size * q_num_heads * q_sequence_length * v_head_size); { - float y_per_head[] = {7.0f / 30.0f, 19.0f / 30.0f}; + float y_per_head[] = {0.2f, 0.6f}; for (int h = 0; h < q_num_heads; ++h) std::fill_n(y.begin() + h * v_head_size, v_head_size, y_per_head[h]); } - // Skip present_key/value validation: trailing positions beyond valid tokens are - // intentionally uninitialized for performance (Flash respects seqlens_k bounds). - // Other tests (AllTrueMask, UnfusedPrompt) verify present_key/value for fully-valid cases. - // We must still declare these outputs (required when past_key is provided), but use - // placeholder values — the op validates shapes, not content, for optional outputs. + // present_key/value: standard concat — all past rows + new at position past_sequence_length. + std::vector present_key(batch_size * kv_num_heads * total_sequence_length * head_size, 0.5f); + std::vector present_value(batch_size * kv_num_heads * total_sequence_length * v_head_size); + { + float pv_expected[2][4] = {{0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.6f, 0.7f, 0.8f}}; + for (int h = 0; h < kv_num_heads; ++h) + for (int s = 0; s < total_sequence_length; ++s) + std::fill_n(present_value.begin() + (h * total_sequence_length + s) * v_head_size, + v_head_size, pv_expected[h][s]); + } OpTester test("Attention", 23, onnxruntime::kOnnxDomain); test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, ToFloat16(q)); @@ -752,92 +759,33 @@ TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { test.AddInput("past_key", {batch_size, kv_num_heads, past_sequence_length, head_size}, ToFloat16(past_key)); test.AddInput("past_value", {batch_size, kv_num_heads, past_sequence_length, v_head_size}, ToFloat16(past_value)); - // Declare all 3 outputs (required for graph construction). present_key/value use - // placeholder expected data — actual validation is done by the custom verifier below. test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, v_head_size}, ToFloat16(y)); - std::vector present_key_placeholder(batch_size * kv_num_heads * total_sequence_length * head_size, 0.0f); - std::vector present_value_placeholder(batch_size * kv_num_heads * total_sequence_length * v_head_size, 0.0f); test.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size}, - ToFloat16(present_key_placeholder)); + ToFloat16(present_key)); test.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, v_head_size}, - ToFloat16(present_value_placeholder)); - - // Custom verifier: validate Y and present_key/value prefix (NaN-safe for uninitialized tail). - // With mask [T,T,T,F] and seqlen_offset=-1: past_seqlens=2, valid positions=[0,3), position 3 uninitialized. 
- auto expected_y_fp16 = ToFloat16(y); - const int valid_seq_positions = 3; // past_seqlens(2) + kv_sequence_length(1) - test.SetCustomOutputVerifier( - [&](const std::vector& fetches, const std::string& provider_type) { - ASSERT_GE(fetches.size(), 3u) << "Expected 3 outputs, provider: " << provider_type; - - // Validate Y (output 0). - const auto& y_tensor = fetches[0].Get(); - auto y_span = y_tensor.DataAsSpan(); - ASSERT_EQ(y_span.size(), expected_y_fp16.size()) << "Y size mismatch, provider: " << provider_type; - for (size_t i = 0; i < y_span.size(); ++i) { - ASSERT_NEAR(y_span[i].ToFloat(), expected_y_fp16[i].ToFloat(), 3e-3f) - << "Y mismatch at " << i << ", provider: " << provider_type; - } + ToFloat16(present_value)); - // Validate present_key prefix (output 1): positions [0, valid_seq_positions) in BNSH layout. - // past_seqlens=2 rows from past_key, then 1 row from k. Position 3 is uninitialized (skip). - const int past_seqlens = valid_seq_positions - kv_sequence_length; // = 2 - { - auto pk_span = fetches[1].Get().DataAsSpan(); - for (int h = 0; h < kv_num_heads; ++h) { - int present_h = h * total_sequence_length * head_size; - int past_h = h * past_sequence_length * head_size; - // Check past rows [0, past_seqlens) - for (int s = 0; s < past_seqlens; ++s) - for (int d = 0; d < head_size; ++d) - ASSERT_NEAR(pk_span[present_h + s * head_size + d].ToFloat(), - past_key[past_h + s * head_size + d], 1e-3f) - << "present_key past mismatch h=" << h << " s=" << s << " d=" << d; - // Check new key row at position past_seqlens - int k_h = h * kv_sequence_length * head_size; - for (int d = 0; d < head_size; ++d) - ASSERT_NEAR(pk_span[present_h + past_seqlens * head_size + d].ToFloat(), - k[k_h + d], 1e-3f) - << "present_key new-key mismatch h=" << h << " d=" << d; - } - } - - // Validate present_value prefix (output 2): same structure with v_head_size. - { - auto pv_span = fetches[2].Get().DataAsSpan(); - for (int h = 0; h < kv_num_heads; ++h) { - int present_h = h * total_sequence_length * v_head_size; - int past_h = h * past_sequence_length * v_head_size; - for (int s = 0; s < past_seqlens; ++s) - for (int d = 0; d < v_head_size; ++d) - ASSERT_NEAR(pv_span[present_h + s * v_head_size + d].ToFloat(), - past_value[past_h + s * v_head_size + d], 1e-3f) - << "present_value past mismatch h=" << h << " s=" << s << " d=" << d; - int v_h = h * kv_sequence_length * v_head_size; - for (int d = 0; d < v_head_size; ++d) - ASSERT_NEAR(pv_span[present_h + past_seqlens * v_head_size + d].ToFloat(), - v[v_h + d], 1e-3f) - << "present_value new-value mismatch h=" << h << " d=" << d; - } - } - // Position 3 (index [h, 3, :]) intentionally not validated — uninitialized for performance. - }); + test.SetOutputAbsErr("Y", 3e-3f); + test.SetOutputAbsErr("present_key", 1e-3f); + test.SetOutputAbsErr("present_value", 1e-3f); + // CPU always runs; CUDA runs when SM 5.3+ is available for fp16. std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + if (HasCudaEnvironment(530)) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// Multi-batch Flash decode with per-batch partial bool masks (CUDA-only, direct OpTester). -// batch_size=2 with different masks: batch 0 has 3 leading trues (past_seqlens=2), -// batch 1 has 6 leading trues (past_seqlens=5). 
Tests LaunchConcatNewToPastKV with -// variable per-batch past_seq_lens and validates present_key/value per-batch. -TEST(AttentionTest, FlashAttention_Decode_PartialMask_MultiBatch_Float16) { - if (!HasCudaEnvironment(530)) { - return; // fp16 requires SM 5.3+ - } - +// Multi-batch decode with per-batch partial bool masks. +// batch_size=2: batch 0 [T,T,T,F,F,F] (3 leading trues), batch 1 [T,T,T,T,T,T] (all true). +// Flash is ineligible (bool+past_key rejected), CUDA uses unfused. +// Unfused applies standard ConcatPastToPresent (new token at position past_sequence_length=5 +// for all batches) and element-wise mask in softmax. +// Runs on both CPU and CUDA to verify cross-EP consistency. +TEST(AttentionTest, Attention4DAttnMaskBoolPartialMask_MultiBatch_Float16) { int batch_size = 2; int q_num_heads = 2; int q_sequence_length = 1; @@ -848,12 +796,10 @@ TEST(AttentionTest, FlashAttention_Decode_PartialMask_MultiBatch_Float16) { int past_sequence_length = 5; int total_sequence_length = past_sequence_length + kv_sequence_length; // 6 - // Uniform Q and K → equal attention scores → softmax is uniform. std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.5f); std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.5f); std::vector past_key(batch_size * kv_num_heads * past_sequence_length * head_size, 0.5f); - // V: [2, 2, 1, 64] — new token values per batch per head. std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size); { float v_new[2][2] = {{0.4f, 0.8f}, {0.6f, 1.0f}}; @@ -863,12 +809,12 @@ TEST(AttentionTest, FlashAttention_Decode_PartialMask_MultiBatch_Float16) { } // past_value: [2, 2, 5, 64] — distinct per-row values. - // Batch 0: only positions 0,1 are valid past (mask has 3 leading trues, past_seqlens=2). - // Batch 1: all 5 positions are valid past (mask has 6 leading trues, past_seqlens=5). + // Batch 0 mask [T,T,T,F,F,F] → 3 valid positions (all past). + // Batch 1 mask [T,T,T,T,T,T] → 6 valid positions (all past + new). std::vector past_value(batch_size * kv_num_heads * past_sequence_length * v_head_size); { float pv[2][2][5] = { - {{0.1f, 0.2f, 0.0f, 0.0f, 0.0f}, {0.5f, 0.6f, 0.0f, 0.0f, 0.0f}}, // batch 0 + {{0.1f, 0.2f, 0.3f, 0.0f, 0.0f}, {0.5f, 0.6f, 0.7f, 0.0f, 0.0f}}, // batch 0 {{0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f}} // batch 1 }; for (int b = 0; b < batch_size; ++b) @@ -879,29 +825,43 @@ TEST(AttentionTest, FlashAttention_Decode_PartialMask_MultiBatch_Float16) { v_head_size, pv[b][h][s]); } - // 4D bool mask: [2, 1, 1, 6] — per-batch varying masks. - // Batch 0: [T,T,T,F,F,F] → 3 leading trues → past_seqlens=2 - // Batch 1: [T,T,T,T,T,T] → 6 leading trues → past_seqlens=5 - // Note: use bool array instead of vector (which is bit-packed and lacks .data()). const bool mask[] = { true, true, true, false, false, false, // batch 0 true, true, true, true, true, true // batch 1 }; - // Y: uniform attention over valid positions. - // Batch 0 (3 valid): head 0: mean(0.1, 0.2, 0.4) = 7/30, head 1: mean(0.5, 0.6, 0.8) = 19/30 - // Batch 1 (6 valid): head 0: mean(0.1..0.5, 0.6) = 0.35, head 1: mean(0.5..0.9, 1.0) = 0.75 + // Y: uniform attention over valid positions (spec-correct). 
+ // Batch 0 (3 valid, all past): head 0: mean(0.1, 0.2, 0.3) = 0.2 + // head 1: mean(0.5, 0.6, 0.7) = 0.6 + // Batch 1 (6 valid, all past + new): head 0: mean(0.1..0.5, 0.6) = 0.35 + // head 1: mean(0.5..0.9, 1.0) = 0.75 std::vector y(batch_size * q_num_heads * q_sequence_length * v_head_size); { float y_per_bh[2][2] = { - {7.0f / 30.0f, 19.0f / 30.0f}, // batch 0 - {0.35f, 0.75f} // batch 1 + {0.2f, 0.6f}, // batch 0 + {0.35f, 0.75f} // batch 1 }; for (int b = 0; b < batch_size; ++b) for (int h = 0; h < q_num_heads; ++h) std::fill_n(y.begin() + (b * q_num_heads + h) * v_head_size, v_head_size, y_per_bh[b][h]); } + // present_key/value: standard concat — all 5 past rows + new at position 5. + std::vector present_key(batch_size * kv_num_heads * total_sequence_length * head_size, 0.5f); + std::vector present_value(batch_size * kv_num_heads * total_sequence_length * v_head_size); + { + float pv_expected[2][2][6] = { + {{0.1f, 0.2f, 0.3f, 0.0f, 0.0f, 0.4f}, {0.5f, 0.6f, 0.7f, 0.0f, 0.0f, 0.8f}}, // batch 0 + {{0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f}} // batch 1 + }; + for (int b = 0; b < batch_size; ++b) + for (int h = 0; h < kv_num_heads; ++h) + for (int s = 0; s < total_sequence_length; ++s) + std::fill_n(present_value.begin() + + ((b * kv_num_heads + h) * total_sequence_length + s) * v_head_size, + v_head_size, pv_expected[b][h][s]); + } + OpTester test("Attention", 23, onnxruntime::kOnnxDomain); test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, ToFloat16(q)); test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(k)); @@ -915,85 +875,27 @@ TEST(AttentionTest, FlashAttention_Decode_PartialMask_MultiBatch_Float16) { test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, v_head_size}, ToFloat16(y)); - std::vector present_key_placeholder(batch_size * kv_num_heads * total_sequence_length * head_size, 0.0f); - std::vector present_value_placeholder(batch_size * kv_num_heads * total_sequence_length * v_head_size, 0.0f); test.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size}, - ToFloat16(present_key_placeholder)); + ToFloat16(present_key)); test.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, v_head_size}, - ToFloat16(present_value_placeholder)); - - // Per-batch valid positions: batch 0 = 3 (past_seqlens=2 + kv_seq=1), batch 1 = 6 (all). - const int valid_positions[] = {3, 6}; - - test.SetCustomOutputVerifier( - [&](const std::vector& fetches, const std::string& provider_type) { - ASSERT_GE(fetches.size(), 3u) << "Expected 3 outputs, provider: " << provider_type; - - // Validate Y (output 0). - auto expected_y_fp16 = ToFloat16(y); - auto y_span = fetches[0].Get().DataAsSpan(); - ASSERT_EQ(y_span.size(), expected_y_fp16.size()) << "Y size mismatch, provider: " << provider_type; - for (size_t i = 0; i < y_span.size(); ++i) { - ASSERT_NEAR(y_span[i].ToFloat(), expected_y_fp16[i].ToFloat(), 3e-3f) - << "Y mismatch at " << i << ", provider: " << provider_type; - } + ToFloat16(present_value)); - // Validate present_key prefix per batch (output 1). 
- { - auto pk_span = fetches[1].Get().DataAsSpan(); - for (int b = 0; b < batch_size; ++b) { - int past_seqlens = valid_positions[b] - kv_sequence_length; - for (int h = 0; h < kv_num_heads; ++h) { - int present_bh = (b * kv_num_heads + h) * total_sequence_length * head_size; - int past_bh = (b * kv_num_heads + h) * past_sequence_length * head_size; - // Check past rows [0, past_seqlens) - for (int s = 0; s < past_seqlens; ++s) - for (int d = 0; d < head_size; ++d) - ASSERT_NEAR(pk_span[present_bh + s * head_size + d].ToFloat(), - past_key[past_bh + s * head_size + d], 1e-3f) - << "present_key past mismatch b=" << b << " h=" << h << " s=" << s << " d=" << d; - // Check new key at position past_seqlens - int k_bh = (b * kv_num_heads + h) * kv_sequence_length * head_size; - for (int d = 0; d < head_size; ++d) - ASSERT_NEAR(pk_span[present_bh + past_seqlens * head_size + d].ToFloat(), - k[k_bh + d], 1e-3f) - << "present_key new-key mismatch b=" << b << " h=" << h << " d=" << d; - } - } - } - - // Validate present_value prefix per batch (output 2). - { - auto pv_span = fetches[2].Get().DataAsSpan(); - for (int b = 0; b < batch_size; ++b) { - int past_seqlens = valid_positions[b] - kv_sequence_length; - for (int h = 0; h < kv_num_heads; ++h) { - int present_bh = (b * kv_num_heads + h) * total_sequence_length * v_head_size; - int past_bh = (b * kv_num_heads + h) * past_sequence_length * v_head_size; - for (int s = 0; s < past_seqlens; ++s) - for (int d = 0; d < v_head_size; ++d) - ASSERT_NEAR(pv_span[present_bh + s * v_head_size + d].ToFloat(), - past_value[past_bh + s * v_head_size + d], 1e-3f) - << "present_value past mismatch b=" << b << " h=" << h << " s=" << s << " d=" << d; - int v_bh = (b * kv_num_heads + h) * kv_sequence_length * v_head_size; - for (int d = 0; d < v_head_size; ++d) - ASSERT_NEAR(pv_span[present_bh + past_seqlens * v_head_size + d].ToFloat(), - v[v_bh + d], 1e-3f) - << "present_value new-value mismatch b=" << b << " h=" << h << " d=" << d; - } - } - } - // Uninitialized tail positions beyond valid_positions[b] per batch intentionally not validated. - }); + test.SetOutputAbsErr("Y", 3e-3f); + test.SetOutputAbsErr("present_key", 1e-3f); + test.SetOutputAbsErr("present_value", 1e-3f); + // CPU always runs; CUDA runs when SM 5.3+ is available for fp16. std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + if (HasCudaEnvironment(530)) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } // MEA/unfused prompt path with fp16 and bool mask (single token, no past KV cache). // past_key/past_value are absent (None); with bool mask and no past_key, CUDA routes to -// MEA/unfused (not Flash — flash_eligible requires past_key != nullptr with bool mask). +// MEA/unfused (not Flash — Flash requires attn_mask == nullptr). TEST(AttentionTest, Attention4DAttnMaskBoolUnfusedPromptFloat16) { int batch_size = 1; int q_num_heads = 2; diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index 54ec3a9111934..ceca17d6fc155 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -863,22 +863,26 @@ def test_gqa_prompt_memory_efficient(self, name, config): # flash attention. 
+# TODO(titaiwang): Re-enable once PR #27851 merges (MEA supports past_key for GQA). +# Flash now rejects attn_mask (requires attn_mask==nullptr). GQA + bool mask + past_key +# has no runner until MEA supports past_key. See issue #27885. +@unittest.skip( + "Flash now rejects attn_mask. GQA + bool mask + past_key has no runner " + "until PR #27851 (MEA with past_key). See issue #27885." +) @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) class TestONNXAttentionPaddingMaskGQA(unittest.TestCase): """ Test ONNX Attention op (opset 23) GQA path with boolean padding masks. - Requires SM80+: decode+padding tests explicitly force Flash via - ORT_DISABLE_FLASH_ATTENTION=0. Prompt+padding uses MEA fallback and is - tested separately in TestONNXAttentionPaddingMaskMemoryEfficientGQA. + SKIPPED: Flash now requires attn_mask == nullptr. GQA + bool attn_mask + + past_key currently has no runner (Flash rejected, unfused doesn't support GQA, + MEA blocked by past_key != nullptr). Will be re-enabled when PR #27851 lands. These tests verify that the boolean attn_mask is correctly converted to sequence lengths on GPU and that the attention computation respects the padding. Tests cover 2D, 3D, and 4D mask shapes. - - Note: prompt+bool_mask is ineligible for flash (routed to MEA), so prompt - padding tests live in TestONNXAttentionPaddingMaskMemoryEfficientGQA only. """ @parameterized.expand(gqa_past_padding_test_cases()) diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index 5cb1e7b7c50b3..abe180ee35787 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -1460,10 +1460,8 @@ def test_2d_bool_mask_batch_gt_qseq(self): attn_mask[0, 4:] = False attn_mask[1, 6:] = False - # Zero out K/V at padded positions (use row 0's pattern since 2D mask broadcasts) - # For bool mask, the effective seqlen for all batches comes from row 0 (most restrictive) - # Actually for cross-attention with different masking per query, just zero out nothing - # The reference uses key_padding_mask for padding, or we can use attn_bias directly + # Build additive bias from the bool mask for the PyTorch reference. + # 2D mask broadcasts identically across batches and heads. mask_filter_value = torch.finfo(torch_type).min attn_bias_ref = torch.where( attn_mask.unsqueeze(0).unsqueeze(0).expand(config.batch_size, config.q_num_heads, -1, -1), @@ -1492,5 +1490,12 @@ def test_2d_bool_mask_batch_gt_qseq(self): numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) +# NOTE: GQA fully-masked batch fix (ZeroOutputForFullyMaskedBatches) is validated by +# C++ test Attention_NonPadKVSeqLen_AllMasked_FP16_GQA. Python graph-level test omitted +# because the fix is a CUDA kernel in the MEA path — a CPU-only test cannot validate it, +# and CUDA debug builds trigger kernel assertions for zero seqlens (pre-existing issue +# with 4D mask alignment in attention_transpose.cu). + + if __name__ == "__main__": unittest.main()
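A minimal NumPy sketch of the reference arithmetic the rewritten tests above rely on: a bool mask is applied as an additive bias (True → 0, False → a large finite mask_filter_value), so a single-query decode with constant Q/K reduces to a uniform mean over the attended V rows. The helper name and the -3.4e38 filter value are illustrative assumptions, not taken from the patch.

import numpy as np

def uniform_mask_reference(v_rows, mask_row, mask_filter_value=-3.4e38):
    # Constant Q/K -> equal raw scores; the bool mask contributes the only difference
    # between positions, so the softmax weights depend on the mask alone.
    scores = np.where(mask_row, 0.0, mask_filter_value).astype(np.float64)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return float(weights @ np.asarray(v_rows, dtype=np.float64))

# Partial mask [T,T,T,F] (new token masked out): mean of the three past V rows.
assert abs(uniform_mask_reference([0.1, 0.2, 0.3, 0.4], [True, True, True, False]) - 0.2) < 1e-9
assert abs(uniform_mask_reference([0.5, 0.6, 0.7, 0.8], [True, True, True, False]) - 0.6) < 1e-9
# All-false mask: every logit gets the same finite bias -> uniform weights over all four rows.
assert abs(uniform_mask_reference([0.1, 0.2, 0.3, 0.4], [False] * 4) - 0.25) < 1e-9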