
Fix 3d attention mask broadcasting in MHA #27464

Merged

xadupre merged 3 commits into main from titaiwang/reorg_onnx_attention_test on Feb 26, 2026
Conversation

@titaiwangms
Contributor

(1) Fix 3d attention mask broadcasting in MHA
(2) Refactor the LLM attention Python tests (add MHA)


This pull request fixes the attention mask broadcasting logic in the CUDA attention kernel, reorganizes the Python attention tests into a test_onnx_attention/ package with new MHA coverage, and adds a missing license header in a test file.

Improvements to attention mask broadcasting logic:

  • Updated the logic in attention.cc to clarify and correct how broadcasting is determined for 3D attention masks: the batch dimension always broadcasts, and the heads dimension broadcasts only if its size is 1. This improves correctness and clarity across the supported mask shapes (see the sketch below).
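A minimal sketch of that rule, for illustration only (the actual logic is C++ in attention.cc; the helper name below is made up):

```python
def broadcast_flags_for_3d_mask(dims):
    """Hypothetical helper mirroring the rule described above.

    A 3D attention mask [A, q_seq, total_seq] is aligned from the right
    against scores of shape [batch, num_heads, q_seq, total_seq], i.e. it
    is treated as [1, A, q_seq, total_seq].
    """
    assert len(dims) == 3
    broadcast_batch = True             # mask has no batch dim, so batch always broadcasts
    broadcast_heads = (dims[0] == 1)   # heads broadcasts only if its size is 1
    return broadcast_batch, broadcast_heads
```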

Documentation and compliance:

  • Added the standard Microsoft MIT license header to the test_onnx_attention/__init__.py file to ensure proper licensing information is included.

titaiwangms and others added 3 commits February 25, 2026 23:56
Split the monolithic test_onnx_attention.py (1651 lines) into a
test_onnx_attention/ package with focused modules:

- common.py: Shared config (AttentionConfig), ONNX graph builders,
  reference attention implementation, mask helpers, and utilities
- test_gqa.py: GQA path tests (kv_num_heads != q_num_heads) covering
  Flash Attention, Memory Efficient Attention, BF16, and padding masks
- test_mha.py: New MHA path tests (kv_num_heads == q_num_heads) covering:
  - Causal self-attention prompt (fp16, fp32, bf16)
  - Non-causal self-attention (encoder models)
  - Cross-attention (encoder-decoder, q_seq != kv_seq)
  - Decoding with KV cache (fp16, fp32)
  - Additive attention bias (2D, 3D, 4D masks with broadcasting)

The MHA tests exercise the unfused attention path in attention.cc
(lines 486-631), which was previously untested. Added support for
the additive mask type (vs. the boolean mask used by GQA), with shape
handling that follows the broadcasting rules in the CUDA implementation;
a rough sketch of the reference bias expansion follows below.
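
For reference, a rough sketch of how a reference implementation can expand an additive bias of any supported rank to the full score shape (the helper name is assumed; the real utilities live in test_onnx_attention/common.py):

```python
import numpy as np

def expand_attn_bias(bias, batch, num_heads, q_seq, total_seq):
    """Expand an additive attention bias to [batch, num_heads, q_seq, total_seq].

    Accepted ranks mirror the broadcasting rules exercised by the tests:
      2D: [q_seq, total_seq]
      3D: [A, q_seq, total_seq]     with A in {1, num_heads}
      4D: [B, A, q_seq, total_seq]  with B in {1, batch}, A in {1, num_heads}
    """
    if bias.ndim == 2:
        bias = bias[None, None, :, :]
    elif bias.ndim == 3:
        bias = bias[None, :, :, :]
    # NumPy broadcasting aligns from the right, just like ONNX broadcasting.
    return np.broadcast_to(bias, (batch, num_heads, q_seq, total_seq))
```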

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The 3D mask broadcast flags were swapped in the MHA path. ONNX
broadcasting aligns from the right, so a 3D mask [A, q_seq, total_seq]
maps to [_, A, q_seq, total_seq] where A is the heads dimension
(validated by attention_helper.h to be 1 or q_num_heads).

The old code incorrectly treated dim[0] as batch:
  broadcast_attn_bias_dim_0 = (dim[0] == 1)  // wrong: batch
  broadcast_attn_bias_dim_1 = true            // wrong: heads

Fixed to:
  broadcast_attn_bias_dim_0 = true                 // batch always broadcasts for 3D
  broadcast_attn_bias_dim_1 = (dim[0] == 1)        // heads broadcasts only if dim[0]==1

This caused the kernel to use the batch index as a stride into the
heads dimension, reading wrong data for batch > 0.

Also updates test mask shapes, reference bias expansion, and fp32
tolerances (TF32 enabled by default on Ampere+ GPUs).
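
To make the failure mode concrete, an illustration-only sketch (not the kernel code) of which plane of a 3D bias is read for a given batch/head under the old and fixed flags:

```python
def bias_plane_index(b, h, A, fixed=True):
    """Which [q_seq, total_seq] plane of a 3D bias [A, q_seq, total_seq]
    is read for batch b, head h. Illustration only, not the kernel code."""
    if fixed:
        # Heads dimension broadcasts only when A == 1; the batch index never
        # contributes, since the bias has no batch dimension.
        return 0 if A == 1 else h
    # Old (swapped) flags: the batch index leaked into the plane selection,
    # per the description above.
    return 0 if A == 1 else b
```

For A == num_heads and batch > 0, the old path selects plane b where plane h is required, matching the wrong reads described above.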

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
xadupre changed the title from "Fix" to "Fix 3d attention mask broadcasting in MHA" on Feb 26, 2026
xadupre enabled auto-merge (squash) on February 26, 2026 at 16:52
xadupre merged commit 654c335 into main on Feb 26, 2026
91 checks passed
xadupre deleted the titaiwang/reorg_onnx_attention_test branch on February 26, 2026 at 16:53