diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
index 0ed6210fb4d29..7daed73c86631 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -546,12 +546,12 @@ Status Attention<T>::RunFlashAttention(
 // ============================================================================
 //
 // Memory Efficient Attention (cutlass FMHA) dispatch paths:
-// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
-// Path 2: no past, with mask (prompt) -> standard MEA with additive bias
-// Path 3: no past, no mask (prompt) -> standard MEA
+// Path 1: Decode with past KV cache -> LaunchConcatNewToPastKV then standard MEA
+// Path 2: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode
+// Path 3: Prompt with mask -> standard MEA with additive bias
+// Path 4: Prompt without mask -> standard MEA
 // Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype,
-// head_size <= 1024), plus: no output_qk, no past_key (decode excluded),
-// bias stride alignment.
+// head_size <= 1024), plus: no output_qk, bias stride alignment.
 // Note: softcap is forwarded to the MEA kernel via p.softcap. softmax_precision
 // is inherently satisfied (cutlass FMHA accumulates softmax in FP32).
 //
@@ -564,8 +564,6 @@ Status Attention<T>::RunMemoryEfficientAttention(
     Tensor* Y, Tensor* present_key, Tensor* present_value,
     const attention_helper::AttentionParameters& parameters) const {
 #if USE_MEMORY_EFFICIENT_ATTENTION
-  ORT_UNUSED_PARAMETER(past_key);
-  ORT_UNUSED_PARAMETER(past_value);
   auto& device_prop = GetDeviceProp();
   auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
   const bool is_bsnh = parameters.transpose_output;
@@ -600,6 +598,106 @@ Status Attention<T>::RunMemoryEfficientAttention(
     out_data = out_bsnh_buffer.get();
   }
 
+  bool present_kv_already_populated = false;
+  // Track the effective layout of k_data/v_data. Initially matches input layout,
+  // but changes to BNSH (false) after decode concat into present buffers.
+  bool kv_is_bsnh = is_bsnh;
+
+  // --- Decode path: concat past + new K/V → present buffers (BNSH) ---
+  // nonpad_kv_seqlen and past_key are mutually exclusive (enforced at validation),
+  // so the decode path only needs the internal-cache (past_key/present_key) flow.
+  if (past_key != nullptr) {
+    ORT_ENFORCE(past_value != nullptr, "past_key requires past_value.");
+    ORT_ENFORCE(present_key != nullptr && present_value != nullptr,
+                "present_key/value outputs are required when past_key is provided.");
+    ORT_ENFORCE(parameters.head_size == parameters.v_head_size,
+                "MEA decode (past_key) requires head_size == v_head_size for LaunchConcatNewToPastKV.");
+
+    using NativeCudaT = typename OrtToCudaType<T>::type;
+
+    // Step 1: Compute per-batch past sequence lengths for the concat kernel.
+    auto past_seqlens_buffer = GetScratchBuffer<int32_t>(parameters.batch_size, context->GetComputeStream());
+    if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
+      size_t mask_dims = attn_mask->Shape().NumDimensions();
+      auto dims = attn_mask->Shape().GetDims();
+      int64_t mask_dim0 = dims[0];
+      int64_t mask_dim1 = mask_dims >= 3 ? dims[1] : 0;
+      int64_t mask_dim2 = mask_dims >= 4 ? dims[2] : 0;
+      // Offset -kv_seq: mask encodes total valid count; subtract to get past-only count.
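The offset arithmetic is easiest to see on the host side. Below is a minimal sketch of the semantics the mask-to-seqlens conversion is expected to implement for a right-padded boolean mask; the helper name and the 2D mask layout are illustrative assumptions, not the kernel's actual signature.

```python
# Host-side sketch: per-batch past sequence lengths from a right-padded bool mask.
# The mask counts all valid positions (past + new), so subtracting the new
# kv_sequence_length leaves the past-only count (the negative seqlen_offset).
import torch

def past_seqlens_from_bool_mask(mask_2d: torch.Tensor, kv_sequence_length: int) -> torch.Tensor:
    # mask_2d: [batch, total_sequence_length], True for valid (non-padded) positions.
    total_valid = mask_2d.sum(dim=-1, dtype=torch.int32)  # past + new valid tokens
    return total_valid - kv_sequence_length                # past-only count per batch

mask = torch.tensor([[True] * 29 + [False] * 4,  # batch 0: 28 valid past + 1 new token
                     [True] * 33])               # batch 1: 32 valid past + 1 new token
print(past_seqlens_from_bool_mask(mask, kv_sequence_length=1))  # tensor([28, 32], dtype=torch.int32)
```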
+      int seqlen_offset = -parameters.kv_sequence_length;
+      ORT_RETURN_IF_ERROR(LaunchConvertMaskToFlashSeqlensK(
+          attn_mask->Data<bool>(), past_seqlens_buffer.get(),
+          parameters.batch_size, parameters.total_sequence_length,
+          static_cast<int>(mask_dims), mask_dim0, mask_dim1, mask_dim2,
+          cuda_stream, device_prop.maxThreadsPerBlock, seqlen_offset));
+    } else {
+      ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length,
+                                          parameters.batch_size, cuda_stream,
+                                          device_prop.maxThreadsPerBlock));
+    }
+
+    // Step 2: Transpose K/V to BSNH if 4D BNSH (concat kernel reads new tokens as BSNH).
+    const T* k_new_bsnh = K->Data<T>();
+    const T* v_new_bsnh = V->Data<T>();
+    IAllocatorUniquePtr<T> k_bsnh_buffer;
+    IAllocatorUniquePtr<T> v_bsnh_buffer;
+    if (!is_bsnh) {
+      size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+                       parameters.kv_num_heads * parameters.head_size;
+      size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length *
+                       parameters.kv_num_heads * parameters.v_head_size;
+      k_bsnh_buffer = GetScratchBuffer<T>(k_bytes, context->GetComputeStream());
+      v_bsnh_buffer = GetScratchBuffer<T>(v_bytes, context->GetComputeStream());
+      ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.head_size,
+          K->Data<T>(), k_bsnh_buffer.get(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+      ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.v_head_size,
+          V->Data<T>(), v_bsnh_buffer.get(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+      k_new_bsnh = static_cast<const T*>(k_bsnh_buffer.get());
+      v_new_bsnh = static_cast<const T*>(v_bsnh_buffer.get());
+    }
+
+    // Step 3: Fused concat: past_key + new_key → present_key (BNSH).
+    // When bool masks produce variable per-batch past_seq_lens, positions in the range
+    // [past_seq_lens[b] + kv_sequence_length, total_sequence_length) are not written by
+    // the concat kernel. Zero the buffers first to prevent NaN propagation — MEA reads
+    // all positions (masked by additive bias), unlike Flash which bounds reads via seqlens_k.
+    CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_key->MutableData<T>(), 0,
+                                         present_key->SizeInBytes(), cuda_stream));
+    CUDA_RETURN_IF_ERROR(cudaMemsetAsync(present_value->MutableData<T>(), 0,
+                                         present_value->SizeInBytes(), cuda_stream));
+    ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV(
+        parameters.batch_size,
+        parameters.kv_num_heads,
+        parameters.head_size,
+        parameters.kv_sequence_length,
+        parameters.past_sequence_length,
+        parameters.total_sequence_length,
+        /*is_bsnh=*/false,
+        past_seqlens_buffer.get(),
+        /*total_seq_lens=*/nullptr,
+        reinterpret_cast<const NativeCudaT*>(past_key->Data<T>()),
+        reinterpret_cast<const NativeCudaT*>(past_value->Data<T>()),
+        reinterpret_cast<const NativeCudaT*>(k_new_bsnh),
+        reinterpret_cast<const NativeCudaT*>(v_new_bsnh),
+        reinterpret_cast<NativeCudaT*>(present_key->MutableData<T>()),
+        reinterpret_cast<NativeCudaT*>(present_value->MutableData<T>()),
+        cuda_stream,
+        device_prop.maxThreadsPerBlock,
+        /*past_only=*/false));
+
+    // Point MEA's K/V inputs at the concatenated present buffers (BNSH).
+    k_data = present_key->Data<T>();
+    v_data = present_value->Data<T>();
+    kv_is_bsnh = false;
+    present_kv_already_populated = true;
+  }
+
   // GQA head expansion: MEA requires matching num_heads for Q/K/V.
   // When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup.
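For reference, the head expansion that the ungroup step performs on device corresponds to repeating each KV head for the query heads that share it. A rough PyTorch equivalent follows; it is illustrative only and not the CUDA kernel.

```python
# Sketch of GQA KV head expansion: make K/V carry q_num_heads heads so the
# cutlass FMHA kernel sees matching head counts for Q, K and V.
import torch

def expand_kv_heads(kv_bnsh: torch.Tensor, q_num_heads: int) -> torch.Tensor:
    # kv_bnsh: [batch, kv_num_heads, total_seq, head_size]
    kv_num_heads = kv_bnsh.shape[1]
    assert q_num_heads % kv_num_heads == 0
    # Each KV head serves q_num_heads // kv_num_heads consecutive query heads.
    return kv_bnsh.repeat_interleave(q_num_heads // kv_num_heads, dim=1)

k = torch.randn(2, 2, 33, 128)                   # 2 KV heads
print(expand_kv_heads(k, q_num_heads=8).shape)   # torch.Size([2, 8, 33, 128])
```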
   const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads;
@@ -640,7 +738,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
           reinterpret_cast<const CudaT*>(v_data),
           parameters.total_sequence_length,
           parameters.total_sequence_length,
-          is_bsnh,
+          kv_is_bsnh,
           cuda_stream,
           device_prop.maxThreadsPerBlock));
 
@@ -649,8 +747,8 @@ Status Attention<T>::RunMemoryEfficientAttention(
     }
   }
 
-  // Note: MEA with past_key/value is handled by the unfused fallback.
-  // The cascade in ComputeInternal ensures past_key == nullptr when we reach here.
+  // Note: When past_key is present (decode), k_data/v_data already point to present
+  // buffers (BNSH) after LaunchConcatNewToPastKV above, so MEA sees the full cache.
 
   // Handle attention mask → attention_bias conversion
   IAllocatorUniquePtr<T> converted_mask_buffer;
@@ -683,7 +781,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
   p.sm = sm;
   p.is_half = std::is_same<T, MLFloat16>::value;
   p.is_bf16 = std::is_same<T, BFloat16>::value;
-  p.is_kv_bsnh = is_bsnh;
+  p.is_kv_bsnh = kv_is_bsnh;
   p.batch_size = parameters.batch_size;
   p.num_heads = parameters.q_num_heads;
   p.sequence_length = parameters.q_sequence_length;
@@ -733,7 +831,7 @@ Status Attention<T>::RunMemoryEfficientAttention(
   p.sm = sm;
   p.is_half = std::is_same<T, MLFloat16>::value;
   p.is_bf16 = std::is_same<T, BFloat16>::value;
-  p.is_kv_bsnh = is_bsnh;
+  p.is_kv_bsnh = kv_is_bsnh;
   p.batch_size = parameters.batch_size;
   p.num_heads = parameters.q_num_heads;
   p.sequence_length = parameters.q_sequence_length;
@@ -775,30 +873,33 @@ Status Attention<T>::RunMemoryEfficientAttention(
         cuda_stream, device_prop.maxThreadsPerBlock));
   }
 
-  // Populate present_key/present_value (BNSH) if requested
-  if (present_key != nullptr && is_bsnh) {
-    ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
-        parameters.batch_size, parameters.kv_sequence_length,
-        parameters.kv_num_heads, parameters.head_size,
-        K->Data<T>(), present_key->MutableData<T>(),
-        cuda_stream, device_prop.maxThreadsPerBlock));
-  } else if (present_key != nullptr && !is_bsnh) {
-    // 4D BNSH prompt: K is already BNSH, just D2D copy to present
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
-        present_key->MutableData<T>(), K->Data<T>(),
-        K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
-  }
-  if (present_value != nullptr && is_bsnh) {
-    ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
-        parameters.batch_size, parameters.kv_sequence_length,
-        parameters.kv_num_heads, parameters.v_head_size,
-        V->Data<T>(), present_value->MutableData<T>(),
-        cuda_stream, device_prop.maxThreadsPerBlock));
-  } else if (present_value != nullptr && !is_bsnh) {
-    // 4D BNSH prompt: V is already BNSH, just D2D copy to present
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
-        present_value->MutableData<T>(), V->Data<T>(),
-        V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+  // Populate present_key/present_value (BNSH) if requested.
+  // Skip for decode path where LaunchConcatNewToPastKV already populated present buffers.
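A reference sketch of the buffer contents the decode path is expected to produce: zero-initialized present buffers, the valid past tokens per batch, then the new tokens, with the unwritten tail left at zero for the additive bias to mask out. Function and variable names below are illustrative, not the kernel API.

```python
# Reference semantics of "concat past + new K/V into present (BNSH)" with
# variable per-batch past lengths, matching the zero-fill convention above.
import torch

def concat_past_and_new(past_bnsh, new_bsnh, past_seqlens, total_seq_len):
    batch, num_heads, _, head_size = past_bnsh.shape
    present = torch.zeros(batch, num_heads, total_seq_len, head_size, dtype=past_bnsh.dtype)
    new_bnsh = new_bsnh.transpose(1, 2)          # new tokens arrive as BSNH
    new_len = new_bnsh.shape[2]
    for b in range(batch):
        p = int(past_seqlens[b])
        present[b, :, :p] = past_bnsh[b, :, :p]            # valid past tokens
        present[b, :, p:p + new_len] = new_bnsh[b]         # newly generated tokens
        # positions [p + new_len, total_seq_len) stay zero; the additive bias masks them
    return present

past = torch.randn(2, 2, 32, 128)                # [batch, kv_heads, past_seq, head_size]
new = torch.randn(2, 1, 2, 128)                  # [batch, new_seq, kv_heads, head_size]
print(concat_past_and_new(past, new, past_seqlens=[28, 32], total_seq_len=33).shape)
```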
+  if (!present_kv_already_populated) {
+    if (present_key != nullptr && is_bsnh) {
+      ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.head_size,
+          K->Data<T>(), present_key->MutableData<T>(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+    } else if (present_key != nullptr && !is_bsnh) {
+      // 4D BNSH prompt: K is already BNSH, just D2D copy to present
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+          present_key->MutableData<T>(), K->Data<T>(),
+          K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+    }
+    if (present_value != nullptr && is_bsnh) {
+      ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(
+          parameters.batch_size, parameters.kv_sequence_length,
+          parameters.kv_num_heads, parameters.v_head_size,
+          V->Data<T>(), present_value->MutableData<T>(),
+          cuda_stream, device_prop.maxThreadsPerBlock));
+    } else if (present_value != nullptr && !is_bsnh) {
+      // 4D BNSH prompt: V is already BNSH, just D2D copy to present
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+          present_value->MutableData<T>(), V->Data<T>(),
+          V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream));
+    }
   }
 
   return Status::OK();
@@ -1148,7 +1249,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
           sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value,
           parameters.head_size, parameters.v_head_size) &&
       !has_output_qk &&
-      past_key == nullptr;
+      // MEA decode requires head_size == v_head_size for LaunchConcatNewToPastKV
+      // (single head_size parameter). Fall back to unfused when they differ.
+      !(past_key != nullptr && parameters.head_size != parameters.v_head_size);
 
   // Cutlass FMHA requires bias strides to satisfy minimum alignment even in the
   // "unaligned" kernel path. When an attention mask is present (with or without
diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py
index 48640fa38aca2..39989df2d51f3 100644
--- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py
+++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py
@@ -88,6 +88,7 @@ class AttentionConfig:
     q_num_heads: int
     kv_num_heads: int
     head_size: int
+    v_head_size: int = 0  # 0 means same as head_size; set explicitly for asymmetric Q/V head sizes
     is_causal: int = 0
     past_kv_sequence_length: int = 0
     softcap: float = 0.0
@@ -135,6 +136,9 @@ def create_attention_node_and_io(
     else:
         # Prompt (no past KV cache)
         present_kv_seqlen = config.kv_sequence_length
+    # Effective v_head_size: defaults to head_size when not explicitly set
+    effective_v_head_size = config.v_head_size or config.head_size
+
     if not config.kv_cache_type:
         config.kv_cache_type = {
             TensorProto.FLOAT16: "float16",
@@ -199,13 +203,14 @@ def create_attention_node_and_io(
             helper.make_tensor_value_info(
                 "value",
                 ort_type,
-                [config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size],
+                [config.batch_size, config.kv_num_heads, config.kv_sequence_length, effective_v_head_size],
             ),
         ]
     else:
         # 3D inputs: [batch, seq_len, hidden_size]
         q_hidden_size = config.q_num_heads * config.head_size
         kv_hidden_size = config.kv_num_heads * config.head_size
+        v_hidden_size = config.kv_num_heads * effective_v_head_size
         graph_input = [
             helper.make_tensor_value_info(
                 "query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]
             ),
             helper.make_tensor_value_info(
                 "key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size]
             ),
             helper.make_tensor_value_info(
-                "value",
ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + "value", ort_type, [config.batch_size, config.kv_sequence_length, v_hidden_size] ), ] @@ -263,10 +268,11 @@ def create_attention_node_and_io( # Shape: [batch, num_heads, past_seq_len, head_size] (4D BNSH format) if is_past: past_k_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, config.head_size] + past_v_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, effective_v_head_size] graph_input.extend( [ helper.make_tensor_value_info("past_key", cache_ort_type, past_k_shape), - helper.make_tensor_value_info("past_value", cache_ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", cache_ort_type, past_v_shape), ] ) @@ -276,16 +282,17 @@ def create_attention_node_and_io( # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + output_v_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, effective_v_head_size] if config.use_4d_bnsh: - output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size] + output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size] else: - output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] + output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * effective_v_head_size] graph_output = [ helper.make_tensor_value_info("output", ort_type, output_shape), helper.make_tensor_value_info("present_key", cache_ort_type, output_k_shape), - helper.make_tensor_value_info("present_value", cache_ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", cache_ort_type, output_v_shape), ] if output_qk > 0: @@ -447,24 +454,26 @@ def attention_prompt_func( bind_tensor(io_binding, "nonpad_kv_seqlen", nonpad_kv_seqlen, device, TensorProto.INT64) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape for prompt (no past) present_seqlen = config.kv_sequence_length - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] # Determine dtype for cache tensors cache_dtype = out_dtype @@ -473,8 +482,8 @@ def attention_prompt_func( else: cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) 
+ present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -565,28 +574,30 @@ def attention_past_func( bind_tensor(io_binding, "past_value", past_v_sliced, device, cache_ort_type) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape (past + new) present_seqlen = total_seq_len - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] cache_dtype = out_dtype - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index 54ec3a9111934..6630c1a79e1c9 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -858,9 +858,41 @@ def test_gqa_prompt_memory_efficient(self, name, config): atol=atol["fp16"], ) - # Note: GQA past tests removed — MEA is ineligible when past_key is present - # (ComputeInternal requires past_key == nullptr for MEA). GQA past requires - # flash attention. 
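The v_head_size plumbing added to common.py above follows one shape convention: Q and K use head_size, while V, present_value, and the output use v_head_size, with 0 meaning "same as head_size". A small sketch of that convention (the helper name here is hypothetical):

```python
# Shape convention assumed by the test helpers when head_size and v_head_size differ.
def attention_io_shapes(batch, q_seq, kv_seq, q_heads, kv_heads, head_size, v_head_size=0):
    v_head_size = v_head_size or head_size  # 0 -> same as head_size
    return {
        "query": [batch, q_seq, q_heads * head_size],
        "key": [batch, kv_seq, kv_heads * head_size],
        "value": [batch, kv_seq, kv_heads * v_head_size],
        "present_key": [batch, kv_heads, kv_seq, head_size],
        "present_value": [batch, kv_heads, kv_seq, v_head_size],
        "output": [batch, q_seq, q_heads * v_head_size],
    }

print(attention_io_shapes(2, 1, 33, 4, 4, head_size=128, v_head_size=96)["output"])  # [2, 1, 384]
```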
+ @parameterized.expand(gqa_past_test_cases()) + def test_gqa_past_memory_efficient(self, name, config): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQABF16(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention using BFloat16.""" + + @parameterized.expand(gqa_past_test_cases()) + def test_gqa_past_memory_efficient_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config.kv_cache_type = "bfloat16" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") @@ -934,6 +966,134 @@ def test_gqa_prompt_padding_mea(self, name, config): atol=atol["fp16"], ) + @parameterized.expand(gqa_past_padding_test_cases()) + def test_gqa_past_padding_mea(self, name, config): + """Test decoding phase with boolean padding mask using Memory Efficient Attention.""" + past_seqlens = torch.full( + (config.batch_size,), + config.past_kv_sequence_length, + dtype=torch.int32, + device="cuda", + ) + + parity_check_gqa_past_with_padding( + config=config, + past_seqlens=past_seqlens, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQAFloatMaskDecode(unittest.TestCase): + """ + Test GQA with float additive attention mask during decode using MEA. + + This exercises the MEA decode path with float additive masks — a scenario + that was a HARD ERROR before the MEA+decode code fix (MEA was ineligible + when past_key was present, so this fell through to no kernel). 
+ """ + + def test_gqa_past_float_mask_4d(self): + """Test GQA decode with 4D float additive mask via MEA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 8, 128, device=device, dtype=torch_type) * std + + past_k = torch.randn(2, 2, 32, 128, device=device, dtype=torch_type) * std + past_v = torch.randn_like(past_k) * std + + new_k = torch.randn(2, 1, 2, 128, device=device, dtype=torch_type) * std + new_v = torch.randn_like(new_k) * std + + total_seq_len = 33 # past(32) + new(1) + + # Create additive mask with padding pattern: batch 0 has 28 valid past, batch 1 full + past_seqlens = torch.tensor([28, 32], dtype=torch.int32, device=device) + total_seqlens = past_seqlens + config.kv_sequence_length + + attn_mask = create_additive_mask_from_seqlens( + seqlens=total_seqlens, + total_seq_len=total_seq_len, + mask_dims=4, + q_seq_len=1, + num_heads=8, + device=device, + dtype=torch_type, + ) + + # Zero padded past positions for batch 0 + past_k[0, :, 28:, :] = 0 + past_v[0, :, 28:, :] = 0 + + # Reference: concat past + new, then compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + # Expand 4D mask to reference attn_bias [batch, heads, q_seq, total_seq] + attn_bias_ref = attn_mask + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_ort = out_ort.reshape(2, 1, 8, 128) + + # --- Verify present_k/v match concatenated reference --- + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # --- Verify output --- + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + # ################################################################################################# # Parity Check with nonpad_kv_seqlen (Opset 24) @@ -1257,6 +1417,31 @@ def test_gqa_4d_bnsh_decode(self, name, config): ) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping 4D BNSH tests.") 
+@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionGQA4DBNSHMEA(unittest.TestCase): + """ + Test GQA with 4D BNSH input format via Memory Efficient Attention. + + Verifies the BNSH transpose logic (use_4d_bnsh=True) works correctly + when MEA handles the decode path. The C++ attention op detects 4D inputs + and sets transpose_output=false; the dispatcher transposes Q internally. + """ + + @parameterized.expand(gqa_4d_bnsh_past_test_cases()) + def test_gqa_4d_bnsh_decode_mea(self, name, config): + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # ################################################################################################# # GQA Float Additive Mask Tests # ################################################################################################# diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index 5cb1e7b7c50b3..da1298b466cd7 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -866,6 +866,29 @@ def test_mha_past_fp32(self, name, config): ) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEA(unittest.TestCase): + """Test ONNX Attention op MHA path — decoding with KV cache via Memory Efficient Attention. + + Explicitly forces MEA by disabling Flash Attention. This verifies that the + MEA decode path works correctly for MHA (kv_num_heads == q_num_heads). + """ + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea(self, name, config): + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.") class TestONNXAttentionMHAAttnBias(unittest.TestCase): """ @@ -1251,10 +1274,98 @@ def test_mha_unfused_fp16(self, name, config): # ################################################################################################# -# Broadcast Mask (1,1,q,kv) Tests +# Asymmetric Head Size Regression Test (MEA → unfused fallback) # ################################################################################################# +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping asymmetric head size tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAAsymmetricHeadSize(unittest.TestCase): + """ + Regression test: MEA gracefully falls back to unfused when head_size != v_head_size + with past_key present (decode phase). + + Without the eligibility guard in ComputeInternal, this configuration would select + MEA which then crashes with ORT_ENFORCE because LaunchConcatNewToPastKV requires + head_size == v_head_size. The guard skips MEA and falls back to unfused attention. + + Uses MHA path (kv_num_heads == q_num_heads) because the GQA path has no unfused + fallback (returns NOT_IMPLEMENTED). 
+ """ + + def test_mha_past_asymmetric_v_head_size(self): + """Verify decode with head_size=128, v_head_size=96 doesn't crash (falls to unfused).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=128, + v_head_size=96, + is_causal=1, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + + # Past KV in BNSH: K uses head_size=128, V uses v_head_size=96 + past_k = torch.randn(2, 4, 32, 128, device=device, dtype=torch_type) * std + past_v = torch.randn(2, 4, 32, 96, device=device, dtype=torch_type) * std + + new_k = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + new_v = torch.randn(2, 1, 4, 96, device=device, dtype=torch_type) * std + + # PyTorch reference: concat past + new, compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + + # ORT path — should fall back to unfused (not crash in MEA) + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=None, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + # Reshape output: [B, q_seq, q_num_heads * v_head_size] → [B, q_seq, q_num_heads, v_head_size] + out_ort = out_ort.reshape(2, 1, 4, 96) + + # Verify present_k and present_v + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # Verify output + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + @unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") class TestONNXAttentionMHABroadcastMask(unittest.TestCase): """