onnxruntime/core/providers/cuda/llm/attention.cc (10 changes: 6 additions & 4 deletions)
@@ -523,17 +523,19 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions();
   auto attn_mask_dims = attn_mask->Shape().GetDims();
   // For 2D mask (q_seq_len, total_seq_len): both batch and heads dimensions need broadcasting
-  // For 3D mask (X, q_seq_len, total_seq_len): batch needs broadcasting if X==1, heads always needs broadcasting
+  // For 3D mask (heads_or_1, q_seq_len, total_seq_len): batch always broadcasts, heads broadcasts if dim[0]==1
   // For 4D mask (B, H, q_seq_len, total_seq_len): check if B==1 and H==1

   if (attn_mask_dims_size == 2) {
     // 2D mask: both dimensions need broadcasting
     contribop_parameters.broadcast_attn_bias_dim_0 = true;
     contribop_parameters.broadcast_attn_bias_dim_1 = true;
   } else if (attn_mask_dims_size == 3) {
-    // 3D mask: dim 0 broadcasts if it's 1, dim 1 (heads) always broadcasts
-    contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
-    contribop_parameters.broadcast_attn_bias_dim_1 = true;
+    // 3D mask [A, q_seq_len, total_seq_len]: right-aligned to [_, A, q_seq, total_seq]
+    // A maps to heads dimension (validated to be 1 or q_num_heads by attention_helper.h)
+    // Batch dimension is missing, so always broadcasts
+    contribop_parameters.broadcast_attn_bias_dim_0 = true;
+    contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1;
   } else {
     // 4D mask: check both dim 0 and dim 1 explicitly
     contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1;
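To make the change concrete: with batch size 4, q_num_heads 8, and a 3D mask of shape [8, q_seq, total_seq], the old code set broadcast_attn_bias_dim_0 = (8 == 1) = false (so the batch dimension was not broadcast even though the mask has no batch axis) and broadcast_attn_bias_dim_1 = true (so one head's mask was replicated across all 8 heads). Since a 3D mask is right-aligned against [batch, heads, q_seq, total_seq], its leading dimension is the heads axis, which is what the new code checks.

A minimal standalone sketch of the corrected flag computation follows. BroadcastFlags and ComputeMaskBroadcast are hypothetical names introduced here for illustration; in the actual kernel the flags are written into contribop_parameters, and shape validation lives in attention_helper.h.

#include <cstdint>
#include <vector>

// Mask dims are right-aligned against [batch, heads, q_seq, total_seq].
// A missing leading dimension behaves like size 1, i.e. it always broadcasts.
struct BroadcastFlags {
  bool dim_0;  // batch dimension broadcasts
  bool dim_1;  // heads dimension broadcasts
};

BroadcastFlags ComputeMaskBroadcast(const std::vector<int64_t>& dims) {
  BroadcastFlags f{};
  switch (dims.size()) {
    case 2:  // (q_seq, total_seq): batch and heads both missing, both broadcast
      f.dim_0 = true;
      f.dim_1 = true;
      break;
    case 3:  // (A, q_seq, total_seq): batch missing; A is heads, 1 or q_num_heads
      f.dim_0 = true;
      f.dim_1 = dims[0] == 1;
      break;
    default:  // 4D (B, H, q_seq, total_seq): check each leading dim explicitly
      f.dim_0 = dims[0] == 1;
      f.dim_1 = dims[1] == 1;
      break;
  }
  return f;
}

For the example above, ComputeMaskBroadcast({8, q_seq, total_seq}) yields {dim_0 = true, dim_1 = false}, matching the patched behavior.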