Fix 3d attention mask broadcasting in MHA #27464
Merged
Conversation
Split the monolithic test_onnx_attention.py (1651 lines) into a test_onnx_attention/ package with focused modules:
- common.py: Shared config (AttentionConfig), ONNX graph builders, reference attention implementation, mask helpers, and utilities
- test_gqa.py: GQA path tests (kv_num_heads != q_num_heads) covering Flash Attention, Memory Efficient Attention, BF16, and padding masks
- test_mha.py: New MHA path tests (kv_num_heads == q_num_heads) covering:
  - Causal self-attention prompt (fp16, fp32, bf16)
  - Non-causal self-attention (encoder models)
  - Cross-attention (encoder-decoder, q_seq != kv_seq)
  - Decoding with KV cache (fp16, fp32)
  - Additive attention bias (2D, 3D, 4D masks with broadcasting)

The MHA tests exercise the unfused attention path in attention.cc (lines 486-631), which was previously untested. Added support for the additive mask type (vs. boolean for GQA) with correct shape handling per the broadcasting rules in the CUDA implementation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
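For orientation, here is a minimal NumPy sketch of a reference attention with an additive bias, in the spirit of what common.py provides; the function name `reference_attention` and its signature are illustrative assumptions, not the actual helpers from the test package.

```python
import numpy as np

def reference_attention(q, k, v, attn_bias=None):
    """Illustrative reference attention with an additive bias (not the real test helper).

    q: [batch, num_heads, q_seq, head_size]
    k: [batch, num_heads, total_seq, head_size]
    v: [batch, num_heads, total_seq, head_size]
    attn_bias: additive bias broadcastable to [batch, num_heads, q_seq, total_seq]
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    # Scaled dot-product scores: [batch, num_heads, q_seq, total_seq]
    scores = np.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if attn_bias is not None:
        # Right-aligned broadcasting: a 2D [q_seq, total_seq], 3D [A, q_seq, total_seq],
        # or 4D [B, A, q_seq, total_seq] bias is expanded onto the scores.
        scores = scores + attn_bias
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bhkd->bhqd", probs, v)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    b, h, q_seq, total_seq, d = 2, 4, 3, 5, 8
    q = rng.standard_normal((b, h, q_seq, d), dtype=np.float32)
    k = rng.standard_normal((b, h, total_seq, d), dtype=np.float32)
    v = rng.standard_normal((b, h, total_seq, d), dtype=np.float32)
    bias3d = rng.standard_normal((h, q_seq, total_seq), dtype=np.float32)
    out = reference_attention(q, k, v, bias3d)  # NumPy right-aligns the 3D bias
    print(out.shape)  # (2, 4, 3, 5)
```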
The 3D mask broadcast flags were swapped in the MHA path. ONNX broadcasting aligns from the right, so a 3D mask `[A, q_seq, total_seq]` maps to `[_, A, q_seq, total_seq]`, where A is the heads dimension (validated by attention_helper.h to be 1 or q_num_heads). The old code incorrectly treated `dim[0]` as the batch dimension:
- `broadcast_attn_bias_dim_0 = (dim[0] == 1)` // wrong: batch
- `broadcast_attn_bias_dim_1 = true` // wrong: heads

Fixed to:
- `broadcast_attn_bias_dim_0 = true` // batch always broadcasts for 3D
- `broadcast_attn_bias_dim_1 = (dim[0] == 1)` // heads broadcast only if dim[0] == 1

The swapped flags caused the kernel to use the batch index as a stride into the heads dimension, reading the wrong data for batch > 0. Also updates test mask shapes, reference bias expansion, and fp32 tolerances (TF32 is enabled by default on Ampere+ GPUs).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
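A small NumPy sketch of the intended flag semantics, assuming the kernel selects the bias slice for a given (batch, head) pair and zeroes each index when the corresponding broadcast flag is set; the helper `bias_offset` below is hypothetical and only mirrors the flag logic described in the commit message, not the actual kernel code.

```python
import numpy as np

def bias_offset(dims, b_idx, h_idx):
    """Illustrative: pick the (batch, head) slice of an additive attention bias.

    dims is the bias shape: 2D [q_seq, total_seq], 3D [A, q_seq, total_seq],
    or 4D [B, A, q_seq, total_seq]. Broadcasting aligns from the right, so a
    3D bias has no batch dimension at all.
    """
    if len(dims) == 4:
        broadcast_dim_0 = dims[0] == 1   # batch broadcasts only if B == 1
        broadcast_dim_1 = dims[1] == 1   # heads broadcast only if A == 1
        b = 0 if broadcast_dim_0 else b_idx
        h = 0 if broadcast_dim_1 else h_idx
        return (b, h)
    if len(dims) == 3:
        broadcast_dim_0 = True           # batch always broadcasts for 3D
        broadcast_dim_1 = dims[0] == 1   # heads broadcast only if dim[0] == 1
        h = 0 if broadcast_dim_1 else h_idx
        return (h,)                      # only a heads index exists
    return ()                            # 2D: same bias for every batch and head

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, heads, q_seq, total_seq = 2, 4, 3, 5
    bias = rng.standard_normal((heads, q_seq, total_seq))
    expected = np.broadcast_to(bias, (batch, heads, q_seq, total_seq))
    for b in range(batch):
        for h in range(heads):
            got = bias[bias_offset(bias.shape, b, h)]
            assert np.array_equal(got, expected[b, h])
    print("3D bias slices match right-aligned broadcasting")
```

With the old (swapped) flags, a 3D bias would have been sliced by the batch index instead of the head index, which matches the wrong reads reported for batch > 0.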
xadupre reviewed Feb 26, 2026
xadupre approved these changes Feb 26, 2026
(1) Fix 3d attention mask broadcasting in MHA
(2) Refactor attention python tests of LLM (add MHA)
This pull request includes a minor fix to the attention mask broadcasting logic in the CUDA attention kernel, as well as the addition of a missing license header in a test file.
Improvements to attention mask broadcasting logic:
Updated attention.cc to clarify and correct how broadcasting is determined for 3D attention masks, ensuring the batch dimension always broadcasts and the heads dimension broadcasts only if its size is 1. This improves correctness and clarity for different mask shapes.
Documentation and compliance:
Added the missing license header to the test_onnx_attention/__init__.py file to ensure proper licensing information is included.