47 changes: 37 additions & 10 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -3,6 +3,7 @@

#include <vector>
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/attention.h"
#include "core/providers/cpu/llm/attention_helper.h"
#include "core/providers/cuda/llm/attention.h"
#include "core/providers/cuda/llm/attention_mask_impl.h"
@@ -376,6 +377,10 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

// GQA only supports masking, not additive bias.
// For bool mask, we need to convert it to sequence lengths on GPU.
// Note: The GQA path interprets 2D bool masks as (batch_size, total_seq_len) since it converts
// masks to seqlens_k directly (bypassing ONNX right-aligned broadcasting). This differs from
// the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
// 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
// Allocate validation result buffer on GPU
auto validation_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
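For context on the GQA-path note above (2D bool masks read as (batch_size, total_seq_len) and reduced to seqlens_k): below is a host-side illustration of that reduction, assuming right-padded masks. It is a hypothetical sketch, not the ORT kernel; the exact seqlens_k convention (e.g. any off-by-one) is ORT-internal, and the helper name BoolMaskToSeqlensKHost is made up.

#include <cstdint>
#include <vector>

// Hypothetical host-side illustration (not the ORT CUDA kernel): each row of a
// (batch_size, total_seq_len) bool mask is reduced to a per-batch key length by
// counting the true entries, assuming valid tokens come first and padding (false)
// sits on the right.
std::vector<int32_t> BoolMaskToSeqlensKHost(const std::vector<uint8_t>& mask,
                                            int64_t batch_size,
                                            int64_t total_seq_len) {
  std::vector<int32_t> seqlens_k(batch_size, 0);
  for (int64_t b = 0; b < batch_size; ++b) {
    int32_t count = 0;
    for (int64_t s = 0; s < total_seq_len; ++s) {
      if (mask[b * total_seq_len + s] != 0) ++count;
    }
    seqlens_k[b] = count;
  }
  return seqlens_k;
}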
@@ -511,15 +516,18 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE;

// Determine broadcast flags for attention_bias (if it exists)
// Note: The new Attention op uses attn_mask as attention_bias
// The attention_bias should be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length)
// attn_mask can be 2D, 3D, or 4D. Broadcasting aligns from the right (trailing dimensions).
// The MHA path uses attn_mask as attention_bias (additive bias added before softmax).
// Bool masks are element-wise converted to additive bias (true → 0.0, false → -inf),
// preserving the original shape, so the same broadcasting rules apply to both types.
//
// ONNX broadcasting is right-aligned to target shape (batch, heads, q_seq, total_seq):
// 2D [A, B] → [1, 1, A, B] : A = q_seq_len, B = total_seq_len
// 3D [A, B, C] → [1, A, B, C] : A = heads, B = q_seq_len, C = total_seq_len
// 4D [A, B, C, D] → [A, B, C, D] : A = batch, B = heads, C = q_seq_len, D = total_seq_len
//
// Note: A 2D mask cannot represent per-batch padding because the batch dimension is broadcast.
// For per-batch boolean padding masks, use 4D shape (batch, 1, 1, total_seq_len).
if (attn_mask != nullptr) {
// TODO(titaiwang, xadupre): attn_mask bool is not supported yet
if (attn_mask->IsDataType<bool>()) {
ORT_THROW("Boolean attn_mask is not supported yet in Attention op (CUDA).");
}

size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
auto attn_mask_dims = attn_mask->Shape().GetDims();
// For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
@@ -546,7 +554,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
contribop_parameters.broadcast_attn_bias_dim_1 = false;
}

contribop_parameters.mask_filter_value = -10000.0f;
contribop_parameters.mask_filter_value = static_cast<float>(std::numeric_limits<T>::lowest());
contribop_parameters.scale = parameters.scale;
contribop_parameters.use_tf32 = UseTF32();
// TODO(titaiwang, xadupre): qk_matmul_output_mode only supports kNone and kQK for now
@@ -584,8 +592,27 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

// Set additional fields
data.bias = nullptr; // New Attention op doesn't have bias
IAllocatorUniquePtr<void> converted_mask_buffer;
if (nullptr != attn_mask) {
data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
if (attn_mask->IsDataType<bool>()) {
// Convert boolean mask to additive attention bias: true -> 0.0, false -> mask_filter_value.
// The conversion is element-wise and preserves the original shape, so the broadcast flags
// set above apply identically to the converted float buffer.
using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
int64_t num_elements = attn_mask->Shape().Size();
converted_mask_buffer = GetScratchBuffer<void>(num_elements * sizeof(NativeCudaT), context->GetComputeStream());
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias<NativeCudaT>(
attn_mask->Data<bool>(),
reinterpret_cast<NativeCudaT*>(converted_mask_buffer.get()),
num_elements,
contribop_parameters.mask_filter_value,
cuda_stream,
GetDeviceProp().maxThreadsPerBlock));
data.attention_bias = reinterpret_cast<const CudaT*>(converted_mask_buffer.get());
} else {
data.attention_bias = reinterpret_cast<const CudaT*>(attn_mask->Data<T>());
}
}
data.qkv_format = contribop_parameters.qkv_format;

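Outside the diff, a small standalone illustration of why the bool-to-bias mapping above works with the new mask_filter_value: adding std::numeric_limits<float>::lowest() to a masked position drives its softmax weight to zero. The scores are made up; this is not ORT code.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Standalone illustration: a boolean mask turned into an additive bias of
// {0, lowest float} zeroes out the softmax weight of the masked position.
int main() {
  std::vector<float> scores = {2.0f, 1.0f, 0.5f};
  std::vector<bool> mask = {true, true, false};  // last position masked out
  const float filter = std::numeric_limits<float>::lowest();

  std::vector<float> biased(scores.size());
  for (size_t i = 0; i < scores.size(); ++i) {
    biased[i] = scores[i] + (mask[i] ? 0.0f : filter);
  }

  // Numerically stable softmax over the biased scores.
  float max_v = *std::max_element(biased.begin(), biased.end());
  float denom = 0.0f;
  std::vector<float> probs(biased.size());
  for (size_t i = 0; i < biased.size(); ++i) {
    probs[i] = std::exp(biased[i] - max_v);
    denom += probs[i];
  }
  for (auto& p : probs) p /= denom;

  // Prints roughly 0.73 0.27 0.00 -- the masked position gets zero weight.
  std::printf("%.2f %.2f %.2f\n", probs[0], probs[1], probs[2]);
  return 0;
}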
45 changes: 45 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
@@ -145,5 +145,50 @@
return CUDA_CALL(cudaGetLastError());
}

template <typename T>
__global__ void ConvertBoolMaskToAttentionBiasKernel(
const bool* __restrict__ attn_mask,
T* __restrict__ attention_bias,
const int64_t num_elements,
const float mask_filter_value) {
for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
idx < num_elements;
idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
attention_bias[idx] = attn_mask[idx] ? T(0.0f) : T(mask_filter_value);
}
}

template <typename T>
Status LaunchConvertBoolMaskToAttentionBias(
const bool* attn_mask_bool,
T* attention_bias,
int64_t num_elements,
float mask_filter_value,
cudaStream_t stream,
int max_threads_per_block) {
if (num_elements == 0) {
return Status::OK();
}

int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), num_elements));
int64_t blocks = (num_elements + threads - 1) / threads;
// Cap the grid size at 65535 blocks (well below the CUDA gridDim.x limit of 2^31 - 1);
// the grid-stride loop in the kernel covers any elements beyond grid_size * threads.
constexpr int64_t kMaxGridDimX = 65535;
unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));

[cpplint] Add #include <algorithm> for min [build/include_what_you_use] (attention_mask_impl.cu:178)
ConvertBoolMaskToAttentionBiasKernel<T><<<grid_size, threads, 0, stream>>>(
attn_mask_bool, attention_bias, num_elements, mask_filter_value);

return CUDA_CALL(cudaGetLastError());
}

template Status LaunchConvertBoolMaskToAttentionBias<float>(
const bool*, float*, int64_t, float, cudaStream_t, int);
template Status LaunchConvertBoolMaskToAttentionBias<__half>(
const bool*, __half*, int64_t, float, cudaStream_t, int);
template Status LaunchConvertBoolMaskToAttentionBias<__nv_bfloat16>(
const bool*, __nv_bfloat16*, int64_t, float, cudaStream_t, int);

} // namespace cuda
} // namespace onnxruntime
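A quick standalone check of the launch-size arithmetic above, assuming maxThreadsPerBlock = 1024 and an example element count; this is not ORT code. It shows that for very large masks the capped grid covers only part of the input per pass, which is exactly what the kernel's grid-stride loop compensates for.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone sketch of the launch-size arithmetic in LaunchConvertBoolMaskToAttentionBias
// (assumed maxThreadsPerBlock = 1024). The grid is capped at 65535 blocks, and the
// grid-stride loop covers elements beyond grid_size * threads.
int main() {
  const int64_t max_threads_per_block = 1024;    // assumed device limit
  const int64_t num_elements = 3'000'000'000LL;  // example mask size
  const int threads = static_cast<int>(std::min(max_threads_per_block, num_elements));
  const int64_t blocks = (num_elements + threads - 1) / threads;
  constexpr int64_t kMaxGridDimX = 65535;
  const auto grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
  const int64_t per_pass = static_cast<int64_t>(grid_size) * threads;
  std::printf("threads=%d blocks=%lld grid=%u elements per grid pass=%lld\n",
              threads, static_cast<long long>(blocks), grid_size,
              static_cast<long long>(per_pass));
  return 0;
}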
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
@@ -52,5 +52,17 @@ Status LaunchConvertMaskToSeqlensK(
cudaStream_t stream,
int max_threads_per_block);

// Convert a boolean attention mask to an additive attention bias for the MHA path.
// Maps true -> 0.0 (attend) and false -> mask_filter_value (mask out).
// The output has the same shape as the input mask.
template <typename T>
Status LaunchConvertBoolMaskToAttentionBias(
const bool* attn_mask_bool,
T* attention_bias,
int64_t num_elements,
float mask_filter_value,
cudaStream_t stream,
int max_threads_per_block);

} // namespace cuda
} // namespace onnxruntime
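An illustrative wrapper showing the call shape of the helper declared above; ConvertMaskForMha is a hypothetical name, the device pointers are assumed valid, and allocation/error handling are elided, so this is a sketch rather than drop-in code.

#include <limits>
// Also requires attention_mask_impl.h (above) and a CUDA-enabled ORT build.

// Hypothetical wrapper (not ORT code): converts a device-resident bool mask with
// num_elements entries into an additive float bias, mapping true -> 0.0f (attend)
// and false -> lowest float (mask out), preserving the mask's shape.
onnxruntime::Status ConvertMaskForMha(const bool* attn_mask_dev,
                                      float* bias_dev,
                                      int64_t num_elements,
                                      cudaStream_t stream,
                                      int max_threads_per_block) {
  return onnxruntime::cuda::LaunchConvertBoolMaskToAttentionBias<float>(
      attn_mask_dev,
      bias_dev,
      num_elements,
      std::numeric_limits<float>::lowest(),
      stream,
      max_threads_per_block);
}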
7 changes: 5 additions & 2 deletions onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -474,7 +474,10 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) {
q, k, v, std::vector<float>(), m, std::vector<float>(), std::vector<float>(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
false, true, true // disable_cpu, disable_cuda, disable_dml
// Note: an all-false bool mask (every position masked) is a degenerate case. It still matches
// because mask_filter_value (~-3.4e38) is so extreme that the QK differences are lost to float
// precision, producing uniform softmax weights on both CPU and CUDA.
false, false, true // disable_cpu, disable_cuda, disable_dml
);
}

@@ -617,7 +620,7 @@ TEST(AttentionTest, Attention4DAttnMaskBool) {
q, k, v, std::vector<float>(), m, std::vector<float>(), std::vector<float>(),
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
false, true, true // disable_cpu, disable_cuda, disable_dml
false, false, true // disable_cpu, disable_cuda, disable_dml
);
}

@@ -205,18 +205,20 @@ def create_attention_node_and_io(
else:
mask_ort_type = ort_type # additive mask uses same type as Q/K/V

# Mask shapes differ between GQA (bool) and MHA (additive) paths:
# GQA bool: 2D=[batch, total_seq], 3D=[heads, q_seq, total_seq], 4D=[batch, heads, q_seq, total_seq]
# MHA additive: 2D=[q_seq, total_seq], 3D=[heads, q_seq, total_seq], 4D=[batch, heads, q_seq, total_seq]
# Mask shapes differ between GQA (bool) and MHA (additive/bool) paths:
# GQA bool: 2D=[batch, total_seq] — GQA converts to seqlens_k directly, bypassing ONNX broadcasting.
# MHA (additive or bool): 2D=[q_seq, total_seq] — follows ONNX right-aligned broadcasting.
# 3D and 4D are the same for both paths.
# ONNX broadcasting aligns from the right: 3D [A, B, C] → [_, A, B, C] where A=heads
if config.attn_mask_type == "bool":
is_gqa = config.kv_num_heads != config.q_num_heads
if config.attn_mask_type == "bool" and is_gqa:
if config.attn_mask_dims == 2:
mask_shape = [config.batch_size, mask_seq_len]
elif config.attn_mask_dims == 3:
mask_shape = [config.q_num_heads, config.q_sequence_length, mask_seq_len]
else: # 4D
mask_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, mask_seq_len]
else: # additive
else: # additive, or bool on MHA path
if config.attn_mask_dims == 2:
mask_shape = [config.q_sequence_length, mask_seq_len]
elif config.attn_mask_dims == 3: