diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc
index c98b9d555896a..618bfcd723073 100644
--- a/onnxruntime/core/providers/cuda/llm/attention.cc
+++ b/onnxruntime/core/providers/cuda/llm/attention.cc
@@ -3,6 +3,7 @@
 #include
 #include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cpu/llm/attention.h"
 #include "core/providers/cpu/llm/attention_helper.h"
 #include "core/providers/cuda/llm/attention.h"
 #include "core/providers/cuda/llm/attention_mask_impl.h"
@@ -376,6 +377,10 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

   // GQA only supports masking, not additive bias.
   // For bool mask, we need to convert it to sequence lengths on GPU.
+  // Note: The GQA path interprets 2D bool masks as (batch_size, total_seq_len) since it converts
+  // masks to seqlens_k directly (bypassing ONNX right-aligned broadcasting). This differs from
+  // the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
+  // 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
   if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
     // Allocate validation result buffer on GPU
     auto validation_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
@@ -511,15 +516,18 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;

   // Determine broadcast flags for attention_bias (if it exists)
-  // Note: The new Attention op uses attn_mask as attention_bias
-  // The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length)
-  // attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions).
+  // The MHA path uses attn_mask as attention_bias (an additive bias added before softmax).
+  // Bool masks are converted element-wise to additive bias (true → 0.0, false → -inf),
+  // preserving the original shape, so the same broadcasting rules apply to both types.
+  //
+  // ONNX broadcasting is right-aligned to the target shape (batch, heads, q_seq, total_seq):
+  //   2D [A, B]       → [1, 1, A, B] : A = q_seq_len, B = total_seq_len
+  //   3D [A, B, C]    → [1, A, B, C] : A = heads, B = q_seq_len, C = total_seq_len
+  //   4D [A, B, C, D] → [A, B, C, D] : A = batch, B = heads, C = q_seq_len, D = total_seq_len
+  //
+  // Note: A 2D mask cannot represent per-batch padding because the batch dimension is broadcast.
+  // For per-batch boolean padding masks, use the 4D shape (batch, 1, 1, total_seq_len).
   if (attn_mask != nullptr) {
-    // TODO(titaiwang, xadupre): attn_mask bool is not supported yet
-    if (attn_mask->IsDataType<bool>()) {
-      ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA).");
-    }
-
     size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
     auto attn_mask_dims = attn_mask->Shape().GetDims();
     // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
@@ -546,7 +554,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     contribop_parameters.broadcast_attn_bias_dim_1 = false;
   }

-  contribop_parameters.mask_filter_value = -10000.0f;
+  contribop_parameters.mask_filter_value = static_cast<float>(std::numeric_limits<float>::lowest());
   contribop_parameters.scale = parameters.scale;
   contribop_parameters.use_tf32 = UseTF32();
   // TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
@@ -584,8 +592,27 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

   // Set additional fields
   data.bias = nullptr;  // New Attention op doesn't have bias
+  IAllocatorUniquePtr<void> converted_mask_buffer;
   if (nullptr != attn_mask) {
-    data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
+    if (attn_mask->IsDataType<bool>()) {
+      // Convert the boolean mask to an additive attention bias: true -> 0.0, false -> mask_filter_value.
+      // The conversion is element-wise and preserves the original shape, so the broadcast flags
+      // set above apply identically to the converted float buffer.
+      using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
+      int64_t num_elements = attn_mask->Shape().Size();
+      converted_mask_buffer = GetScratchBuffer<void>(num_elements * sizeof(NativeCudaT), context->GetComputeStream());
+      auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
+      ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias(
+          attn_mask->Data<bool>(),
+          reinterpret_cast<NativeCudaT*>(converted_mask_buffer.get()),
+          num_elements,
+          contribop_parameters.mask_filter_value,
+          cuda_stream,
+          GetDeviceProp().maxThreadsPerBlock));
+      data.attention_bias = reinterpret_cast<const NativeCudaT*>(converted_mask_buffer.get());
+    } else {
+      data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
+    }
   }
   data.qkv_format = contribop_parameters.qkv_format;
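The conversion wired up above is purely element-wise and shape-preserving, which is why the broadcast flags computed for the bool mask carry over to the converted buffer unchanged. A minimal host-side sketch of the mapping (illustration only; `BoolMaskToAdditiveBias` is a hypothetical name, not part of the patch):

```cpp
// Illustration only: CPU-reference semantics of the bool -> additive-bias
// conversion performed on GPU by LaunchConvertBoolMaskToAttentionBias.
#include <limits>
#include <vector>

std::vector<float> BoolMaskToAdditiveBias(const std::vector<bool>& mask) {
  const float mask_filter_value = std::numeric_limits<float>::lowest();
  std::vector<float> bias(mask.size());
  for (size_t i = 0; i < mask.size(); ++i) {
    // true = attend (no penalty), false = mask out (huge negative logit)
    bias[i] = mask[i] ? 0.0f : mask_filter_value;
  }
  return bias;  // same element count and shape as the input mask
}
```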
diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
index f215f60d74288..7e1a7f47ffc70 100644
--- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
+++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
@@ -145,5 +145,50 @@ Status LaunchConvertMaskToSeqlensK(
   return CUDA_CALL(cudaGetLastError());
 }

+template <typename T>
+__global__ void ConvertBoolMaskToAttentionBiasKernel(
+    const bool* __restrict__ attn_mask,
+    T* __restrict__ attention_bias,
+    const int64_t num_elements,
+    const float mask_filter_value) {
+  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+       idx < num_elements;
+       idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
+    attention_bias[idx] = attn_mask[idx] ? T(0.0f) : T(mask_filter_value);
+  }
+}
+
+template <typename T>
+Status LaunchConvertBoolMaskToAttentionBias(
+    const bool* attn_mask_bool,
+    T* attention_bias,
+    int64_t num_elements,
+    float mask_filter_value,
+    cudaStream_t stream,
+    int max_threads_per_block) {
+  if (num_elements == 0) {
+    return Status::OK();
+  }
+
+  int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), num_elements));
+  int64_t blocks = (num_elements + threads - 1) / threads;
+  // Cap the grid at 65535 blocks, well under the CUDA gridDim.x limit (2^31 - 1);
+  // the grid-stride loop in the kernel covers any remaining elements.
+  constexpr int64_t kMaxGridDimX = 65535;
+  unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
+
+  ConvertBoolMaskToAttentionBiasKernel<<<grid_size, threads, 0, stream>>>(
+      attn_mask_bool, attention_bias, num_elements, mask_filter_value);
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchConvertBoolMaskToAttentionBias<float>(
+    const bool*, float*, int64_t, float, cudaStream_t, int);
+template Status LaunchConvertBoolMaskToAttentionBias<__half>(
+    const bool*, __half*, int64_t, float, cudaStream_t, int);
+template Status LaunchConvertBoolMaskToAttentionBias<__nv_bfloat16>(
+    const bool*, __nv_bfloat16*, int64_t, float, cudaStream_t, int);
+
 }  // namespace cuda
 }  // namespace onnxruntime
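The launcher caps the grid and relies on the grid-stride loop for coverage. A standalone nvcc-compilable sketch (hypothetical `FillKernel`, no ORT dependencies) demonstrating that a capped grid still touches every element:

```cuda
// Standalone sketch (not part of the patch): a grid-stride loop covers
// num_elements even when the grid is capped well below the blocks needed.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void FillKernel(float* out, int64_t num_elements, float value) {
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num_elements;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    out[i] = value;  // each thread strides by the total launched thread count
  }
}

int main() {
  const int64_t n = int64_t{1} << 26;  // far more elements than 65535 * 256 threads
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  const int threads = 256;
  const int64_t blocks_needed = (n + threads - 1) / threads;
  const unsigned int grid = static_cast<unsigned int>(blocks_needed < 65535 ? blocks_needed : 65535);
  FillKernel<<<grid, threads>>>(d, n, 1.0f);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d);
  return 0;
}
```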
diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
index 3754143bd6363..004d6119e4cd4 100644
--- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
+++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
@@ -52,5 +52,17 @@ Status LaunchConvertMaskToSeqlensK(
     cudaStream_t stream,
     int max_threads_per_block);

+// Converts a boolean attention mask to an additive attention bias for the MHA path.
+// Maps true -> 0.0 (attend) and false -> mask_filter_value (mask out).
+// The output has the same shape as the input mask.
+template <typename T>
+Status LaunchConvertBoolMaskToAttentionBias(
+    const bool* attn_mask_bool,
+    T* attention_bias,
+    int64_t num_elements,
+    float mask_filter_value,
+    cudaStream_t stream,
+    int max_threads_per_block);
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
index 86bbfb172c7ee..b0c6c6d801c4b 100644
--- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
+++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -474,7 +474,10 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) {
       q, k, v, std::vector<float>(), m, std::vector<float>(), std::vector<float>(),
       -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,  // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
       y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
-      false, true, true  // disable_cpu, disable_cuda, disable_dml
+      // Note: an all-false bool mask (every position masked) is a degenerate case. It still passes
+      // because mask_filter_value (~ -3.4e38) swamps the QK logits in float arithmetic, so all
+      // biased logits become equal and softmax produces uniform weights, matching the CPU behavior.
+      false, false, true  // disable_cpu, disable_cuda, disable_dml
   );
 }
@@ -617,7 +620,7 @@ TEST(AttentionTest, Attention4DAttnMaskBool) {
       q, k, v, std::vector<float>(), m, std::vector<float>(), std::vector<float>(),
       -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat,  // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
       y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
-      false, true, true  // disable_cpu, disable_cuda, disable_dml
+      false, false, true  // disable_cpu, disable_cuda, disable_dml
   );
 }
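The all-false-mask note above can be checked numerically: adding `std::numeric_limits<float>::lowest()` absorbs any finite logit in float arithmetic, so all biased logits compare equal and softmax is uniform. A small host-side check (illustration only, not part of the patch):

```cpp
// Illustration only: adding lowest() swamps finite QK logits, so softmax
// over an all-false-masked row is exactly uniform in float arithmetic.
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const float mask_filter_value = std::numeric_limits<float>::lowest();  // ~ -3.4e38
  float logits[3] = {1.5f, -0.25f, 0.75f};
  for (float& x : logits) x += mask_filter_value;  // each entry rounds to exactly lowest()
  // Numerically stable softmax (subtract the max): equal inputs -> uniform output.
  float denom = 0.0f;
  for (float x : logits) denom += std::exp(x - mask_filter_value);
  for (float x : logits) printf("%f ", std::exp(x - mask_filter_value) / denom);  // 0.333333 each
  printf("\n");
  return 0;
}
```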
diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py
index e94c8c9034337..10a38329549a8 100644
--- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py
+++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py
@@ -205,18 +205,20 @@ def create_attention_node_and_io(
     else:
         mask_ort_type = ort_type  # additive mask uses same type as Q/K/V

-    # Mask shapes differ between GQA (bool) and MHA (additive) paths:
-    # GQA bool: 2D=[batch, total_seq], 3D=[heads, q_seq, total_seq], 4D=[batch, heads, q_seq, total_seq]
-    # MHA additive: 2D=[q_seq, total_seq], 3D=[heads, q_seq, total_seq], 4D=[batch, heads, q_seq, total_seq]
+    # Mask shapes differ between the GQA (bool) and MHA (additive/bool) paths:
+    # GQA bool: 2D=[batch, total_seq] — GQA converts the mask to seqlens_k directly, bypassing ONNX broadcasting.
+    # MHA (additive or bool): 2D=[q_seq, total_seq] — follows ONNX right-aligned broadcasting.
+    # 3D and 4D shapes are the same for both paths.
     # ONNX broadcasting aligns from the right: 3D [A, B, C] → [_, A, B, C] where A=heads
-    if config.attn_mask_type == "bool":
+    is_gqa = config.kv_num_heads != config.q_num_heads
+    if config.attn_mask_type == "bool" and is_gqa:
         if config.attn_mask_dims == 2:
             mask_shape = [config.batch_size, mask_seq_len]
         elif config.attn_mask_dims == 3:
             mask_shape = [config.q_num_heads, config.q_sequence_length, mask_seq_len]
         else:  # 4D
             mask_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, mask_seq_len]
-    else:  # additive
+    else:  # additive, or bool on the MHA path
         if config.attn_mask_dims == 2:
             mask_shape = [config.q_sequence_length, mask_seq_len]
         elif config.attn_mask_dims == 3:
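To make the GQA/MHA shape split concrete, here is a hypothetical C++ mirror of the selection logic above (illustrative names only, not ORT API or part of the patch):

```cpp
// Hypothetical mirror (not ORT API) of common.py's mask-shape selection.
#include <cstdint>
#include <vector>

std::vector<int64_t> MaskShape(bool gqa_bool_path, int rank, int64_t batch,
                               int64_t heads, int64_t q_seq, int64_t total_seq) {
  if (gqa_bool_path && rank == 2) {
    // GQA converts the bool mask to seqlens_k directly, so 2D carries the batch dim.
    return {batch, total_seq};
  }
  if (rank == 2) return {q_seq, total_seq};          // MHA 2D: batch and heads broadcast
  if (rank == 3) return {heads, q_seq, total_seq};   // same for both paths
  return {batch, heads, q_seq, total_seq};           // same for both paths
}
```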
diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py
index 59367ffbf5f54..daa644f40ff41 100644
--- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py
+++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py
@@ -32,6 +32,7 @@
     attention_prompt_func,
     attention_ref,
     create_additive_mask_from_seqlens,
+    create_boolean_mask_from_seqlens,
     enable_deterministic_check,
     has_cuda_device,
     pipeline_mode,
@@ -559,6 +560,169 @@ def mha_attn_bias_test_cases():
         yield name, config


+def mha_bool_mask_test_cases():
+    """
+    Generate test cases for the MHA path with a boolean attention mask.
+
+    Tests 2D, 3D, and 4D boolean masks for right-padding scenarios.
+    The MHA path in attention.cc converts bool masks to an additive bias
+    (True -> 0.0, False -> mask_filter_value).
+
+    For the MHA path, ONNX right-aligned broadcasting maps:
+      2D [q_seq, total_seq] → [1, 1, q_seq, total_seq] (all batches share one mask)
+      3D [heads, q_seq, total_seq] → [1, heads, q_seq, total_seq]
+      4D [batch, heads, q_seq, total_seq] → per-batch, per-head masks
+    """
+    batches = [2]
+    seqs = [(16, 16)]
+    heads = [(8, 8)]
+    h_sizes = [128]
+    mask_dims_options = [2, 3, 4]
+
+    for h in h_sizes:
+        for b in batches:
+            for sq, skv in seqs:
+                for n, n2 in heads:
+                    for mask_dims in mask_dims_options:
+                        config = AttentionConfig(
+                            batch_size=b,
+                            q_sequence_length=sq,
+                            kv_sequence_length=skv,
+                            past_kv_sequence_length=0,
+                            q_num_heads=n,
+                            kv_num_heads=n2,
+                            head_size=h,
+                            is_causal=0,
+                            has_attn_mask=True,
+                            attn_mask_dims=mask_dims,
+                            attn_mask_type="bool",
+                        )
+                        name = f"b{b}_sq{sq}_skv{skv}_nh{n}_h{h}_bool{mask_dims}d"
+                        yield name, config
+
+
+def parity_check_mha_prompt_with_bool_mask(
+    config: AttentionConfig,
+    seqlens: torch.Tensor,
+    ep,
+    device,
+    torch_type,
+    ort_type,
+    rtol,
+    atol,
+    std=0.2,
+):
+    """
+    Parity check for the ONNX Attention op MHA path with a boolean attention mask.
+
+    The MHA path converts bool masks to an additive bias (True -> 0.0, False -> mask_filter_value).
+    Tests 2D, 3D, and 4D boolean masks with padding simulation.
+    """
+    torch.manual_seed(0)
+
+    # Compute effective per-batch seqlens based on mask broadcasting.
+    # For a 2D bool mask [q_seq, total_seq]: all batches share the same mask (first batch's pattern).
+    # For a 3D bool mask [heads, q_seq, total_seq]: the batch dim broadcasts, use the first batch's pattern.
+    # For a 4D bool mask [batch, heads, q_seq, total_seq]: per-batch seqlens apply directly.
+    effective_seqlens = seqlens.clone()
+    if config.attn_mask_dims in (2, 3):
+        effective_seqlens[:] = seqlens[0]
+
+    q = (
+        torch.randn(
+            config.batch_size,
+            config.q_sequence_length,
+            config.q_num_heads,
+            config.head_size,
+            device=device,
+            dtype=torch_type,
+        )
+        * std
+    )
+    k = (
+        torch.randn(
+            config.batch_size,
+            config.kv_sequence_length,
+            config.kv_num_heads,
+            config.head_size,
+            device=device,
+            dtype=torch_type,
+        )
+        * std
+    )
+    v = torch.randn_like(k) * std
+
+    # Zero out padded positions in K and V based on the effective seqlens
+    for b in range(config.batch_size):
+        valid_len = effective_seqlens[b].item()
+        if valid_len < config.kv_sequence_length:
+            k[b, valid_len:, :, :] = 0
+            v[b, valid_len:, :, :] = 0
+
+    # Create the boolean mask for ORT.
+    # For the MHA path, the 2D bool mask shape is [q_seq, total_seq] per ONNX broadcasting rules,
+    # so we build it from the first batch's seqlen (all batches share the same mask).
+    if config.attn_mask_dims == 2:
+        # 2D: [q_seq, total_seq] — a single mask pattern shared by all batches
+        arange = torch.arange(config.kv_sequence_length, device=device)
+        mask_1d = arange < seqlens[0]  # [total_seq]
+        attn_mask = mask_1d.unsqueeze(0).expand(config.q_sequence_length, -1).contiguous()  # [q_seq, total_seq]
+    else:
+        attn_mask = create_boolean_mask_from_seqlens(
+            seqlens=seqlens,
+            total_seq_len=config.kv_sequence_length,
+            mask_dims=config.attn_mask_dims,
+            q_seq_len=config.q_sequence_length,
+            num_heads=config.q_num_heads,
+            device=device,
+        )
+
+    # Create a 2D key_padding_mask for the reference (per-batch, shape [batch, total_seq])
+    key_padding_mask = create_boolean_mask_from_seqlens(
+        seqlens=effective_seqlens,
+        total_seq_len=config.kv_sequence_length,
+        mask_dims=2,
+        device=device,
+    )
+
+    # --- PyTorch reference path ---
+    out_ref, _ = attention_ref(
+        q=q,
+        k=k,
+        v=v,
+        key_padding_mask=key_padding_mask,
+        causal=config.is_causal == 1,
+    )
+
+    # --- ONNX Runtime path ---
+    out, present_k, present_v = attention_prompt_func(
+        q=q,
+        k=k,
+        v=v,
+        config=config,
+        attn_mask=attn_mask,
+        ep=ep,
+        device=device,
+        ort_type=ort_type,
+    )
+
+    out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size))
+
+    # --- Comparison ---
+    # Zero out padded positions in both outputs based on the effective seqlens
+    for b in range(config.batch_size):
+        valid_len = effective_seqlens[b].item()
+        if valid_len < config.q_sequence_length:
+            out[b, valid_len:, :, :] = 0
+            out_ref[b, valid_len:, :, :] = 0
+
+    out_np = out.to(torch.float32).detach().cpu().numpy()
+    out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy()
+
+    print_diff_statistics(torch.tensor(out_np - out_ref_np), "out")
+    numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol)
+
+
 # #################################################################################################
 # Unit Test Classes
 # #################################################################################################
@@ -726,5 +890,35 @@ def test_mha_attn_bias_fp16(self, name, config):
         )


+@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.")
+class TestONNXAttentionMHABoolMask(unittest.TestCase):
+    """
+    Test the ONNX Attention op MHA path with a boolean attention mask.
+
+    Tests 2D, 3D, and 4D boolean masks that are converted to an additive bias
+    (True -> 0.0, False -> mask_filter_value) in attention.cc. This exercises
+    the LaunchConvertBoolMaskToAttentionBias kernel for the MHA path.
+    """
+
+    @parameterized.expand(mha_bool_mask_test_cases())
+    def test_mha_bool_mask_fp16(self, name, config):
+        seqlens = torch.tensor(
+            [config.kv_sequence_length - 6, config.kv_sequence_length],
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        parity_check_mha_prompt_with_bool_mask(
+            config=config,
+            seqlens=seqlens,
+            ep="CUDAExecutionProvider",
+            device="cuda",
+            torch_type=torch.float16,
+            ort_type=TensorProto.FLOAT16,
+            rtol=rtol["fp16"],
+            atol=atol["fp16"],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
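As the attention.cc comment notes, per-batch padding on the MHA path needs the 4D (batch, 1, 1, total_seq_len) mask. A host-side sketch of building one from per-batch valid lengths (illustration only; `MakePaddingMask4D` is a hypothetical helper, not part of the patch):

```cpp
// Illustration only: build the recommended 4D per-batch boolean padding mask
// (batch, 1, 1, total_seq_len), flattened row-major, from per-batch valid lengths.
#include <cstdint>
#include <vector>

std::vector<uint8_t> MakePaddingMask4D(const std::vector<int64_t>& seqlens_k,
                                       int64_t total_seq_len) {
  const int64_t batch = static_cast<int64_t>(seqlens_k.size());
  std::vector<uint8_t> mask(static_cast<size_t>(batch * total_seq_len));
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t t = 0; t < total_seq_len; ++t) {
      // true (1) = attend, false (0) = padded position
      mask[static_cast<size_t>(b * total_seq_len + t)] = (t < seqlens_k[b]) ? 1 : 0;
    }
  }
  return mask;
}
```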