
Fix GPT-2 no-past attention fusion for transformers >= 4.27 #27449

Merged

tianleiwu merged 3 commits into microsoft:main from Rishi-Dave:rishidave/fix/gpt2-no-past-attention-fusion on Mar 5, 2026

Conversation

@Rishi-Dave
Contributor

Summary

  • Fix FusionGptAttentionNoPast mask pattern matching to support both torch.uint8 (old) and torch.bool (new) causal masks
  • Add synthetic ONNX graph generator and unit test for the no-past attention fusion path

Motivation

Fixes #16453

In transformers >= 4.27 (Feb 2023), the causal attention mask dtype changed from torch.uint8 to torch.bool (commit). This removed a Cast node from the exported ONNX graph. FusionGptAttentionNoPast.fuse() hardcoded Cast as the first element in match_parent_path, causing the mask path match to fail silently for all modern transformers exports. The result: zero Attention nodes fused for any GPT-2 model exported without past state.

The sibling class FusionGptAttention (with-past) was already fixed to handle both patterns using match_parent_paths (plural). This PR applies the same approach to the no-past variant.
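The effect of the dtype change on the exported graph can be illustrated with a minimal sketch (plain Python stand-ins for the tensors; not the actual transformers code):

```python
# Minimal sketch of the causal-mask dtype change (plain Python stand-ins,
# not the transformers source). The positions allowed by the mask are
# identical; only the element type differs.

def causal_mask(n, as_bool):
    """Lower-triangular causal mask: truthy where attention is allowed."""
    if as_bool:
        # transformers >= 4.27: bool mask feeds the Where node directly
        return [[j <= i for j in range(n)] for i in range(n)]
    # transformers < 4.27: uint8-style mask; the exporter inserts a
    # Cast-to-bool node before the Where node
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

old = causal_mask(3, as_bool=False)
new = causal_mask(3, as_bool=True)
# Same semantics, so the only graph-level difference is the Cast node
# that the old uint8 mask required.
assert [[bool(v) for v in row] for row in old] == new
```

Because the mask semantics are unchanged, the fusion only needs to tolerate the presence or absence of that one Cast node.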

Changes

fusion_gpt_attention_no_past.py

  • Replace match_parent_path with match_parent_paths for the Where-based mask path (lines 187-201), offering both the Cast-prefixed pattern (old transformers) and Cast-less pattern (transformers >= 4.27)
  • Remove stale TODO comment that noted the fusion "stopped working"
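The shape of the fix can be sketched as follows. This is an illustrative stand-in for `match_parent_paths` (not the onnxruntime implementation); the chains and candidate paths are hypothetical op-type lists:

```python
# Illustrative sketch of trying multiple candidate parent paths, the way
# match_parent_paths (plural) does, instead of hardcoding one path.
# Chains and candidates here are hypothetical op-type lists.

def match_first_path(chain, candidate_paths):
    """Return the first candidate path that matches the chain's prefix."""
    for path in candidate_paths:
        if chain[:len(path)] == path:
            return path
    return None

# Old export (transformers < 4.27): uint8 mask needs a Cast before Where.
old_chain = ["Cast", "Slice", "Slice", "Unsqueeze"]
# New export (transformers >= 4.27): bool mask feeds Where directly.
new_chain = ["Slice", "Slice", "Unsqueeze"]

candidates = [
    ["Cast", "Slice", "Slice"],  # Cast-prefixed pattern (old)
    ["Slice", "Slice"],          # Cast-less pattern (new)
]

print(match_first_path(old_chain, candidates))  # ['Cast', 'Slice', 'Slice']
print(match_first_path(new_chain, candidates))  # ['Slice', 'Slice']
```

With a single hardcoded Cast-prefixed path, the second chain would never match, which is exactly the silent failure the PR fixes.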

gpt2_model_generator.py

  • Add create_gpt2_attention_no_past() function that builds a synthetic GPT-2 no-past attention graph with the Where-based mask pattern
  • Supports add_cast parameter to test both mask variants

test_attention_fusion.py

  • Add test_gpt2_attention_no_past_fusion() that verifies an Attention node is fused for all combinations of add_cast and switch_add_inputs
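The four-way parameter sweep might look like the following sketch; the generator here is a hypothetical stub (the real `create_gpt2_attention_no_past` builds an ONNX graph):

```python
# Sketch of enumerating the four test variants with itertools.product.
from itertools import product

# Hypothetical stub standing in for the PR's generator; the real
# function constructs a synthetic GPT-2 no-past attention ONNX graph.
def create_gpt2_attention_no_past(add_cast, switch_add_inputs):
    return {"add_cast": add_cast, "switch_add_inputs": switch_add_inputs}

variants = list(product([False, True], repeat=2))
models = [create_gpt2_attention_no_past(a, s) for a, s in variants]
# The real test would optimize each model and verify that exactly one
# Attention node was fused (via verify_fusion against golden files).
print(len(models))  # 4
```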

Test Plan

  • New test test_gpt2_attention_no_past_fusion passes (4 variants: with/without Cast × normal/switched Add inputs)
  • All existing attention fusion tests pass (10/10)
  • Lint clean on modified files (lintrunner reports no issues for new code)

@Rishi-Dave
Contributor Author

@microsoft-github-policy-service agree

@Rishi-Dave
Contributor Author

Thanks for the review! Updated the test to use verify_fusion with golden expected model files, matching the pattern in the other tests.

Added two expected model files:

  • gpt2_attention_no_past_opt.onnx (normal add order)
  • gpt2_attention_no_past_add_opt.onnx (switched add inputs)

Both add_cast variants (with/without Cast node) produce identical fused output, so they share the same expected file. All 10 tests pass locally.

xadupre
xadupre previously approved these changes Feb 26, 2026
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu
Contributor

@Rishi-Dave, please merge latest main and resolve the conflicts.

Rishi-Dave and others added 3 commits March 3, 2026 00:11
In transformers >= 4.27, the causal attention mask changed from
torch.uint8 to torch.bool, removing the Cast node from the ONNX graph.
FusionGptAttentionNoPast.fuse() hardcoded Cast as the first element in
the mask match path, causing fusion to silently fail for all modern
transformers exports.

Replace match_parent_path with match_parent_paths to try both the old
Cast-prefixed pattern and the new Cast-less pattern, mirroring the fix
already applied to FusionGptAttention (with-past) at lines 426-438.

Add a synthetic no-past graph generator and unit test covering both
mask variants (with and without Cast).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace manual Attention node count assertion with verify_fusion()
to match the pattern used by all other tests in the suite. Add
golden expected model files for both add-input orderings.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Rishi-Dave Rishi-Dave force-pushed the rishidave/fix/gpt2-no-past-attention-fusion branch from 65bee87 to 604aa89 Compare March 3, 2026 19:39
@Rishi-Dave
Contributor Author

Rebased onto latest main (which now includes the merged #27418 and #27489). Conflict in test_attention_fusion.py imports resolved — all 14 tests pass.

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

@tianleiwu tianleiwu left a comment


LGTM

gpt2_model_generator.py — Test Graph Generator (+378 lines)

Minor Issues

  1. Typo: epsion (around lines ~944 and ~1306 in the generator). This is a known pre-existing typo for the epsilon attribute of LayerNormalization that appears throughout the file. Not introduced by this PR, but worth noting.

  2. Large function: At 370+ lines, the generator is verbose, but this is consistent with the other generators in the file (create_gpt2_attention is similarly long). Acceptable.

test_attention_fusion.py — New Test (+32 lines)

Suggestion

The test saves the model to "." (current working directory) which is fragile — it depends on the CWD when tests are run. Consider using tempfile.TemporaryDirectory() for the intermediate model path, similar to how other tests in the file handle temp files. However, the existing test_gpt2_attention_fusion test uses the same pattern, so this is consistent.
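The suggested approach would look roughly like this sketch (the file name is hypothetical):

```python
# Sketch of using tempfile.TemporaryDirectory for the intermediate model
# path instead of the current working directory. File name is hypothetical.
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "gpt2_attention_no_past.onnx")
    # The test would save the synthetic model here, e.g.:
    # onnx.save(model, model_path)
    with open(model_path, "wb") as f:
        f.write(b"")  # placeholder write for the sketch
    assert os.path.exists(model_path)
# The directory and everything in it are removed automatically on exit.
```

This keeps test artifacts out of the CWD and guarantees cleanup even when the test fails partway through.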

@tianleiwu tianleiwu enabled auto-merge (squash) March 5, 2026 06:42
@tianleiwu tianleiwu merged commit 01a56ce into microsoft:main Mar 5, 2026
84 of 89 checks passed


Development

Successfully merging this pull request may close these issues.

GPT2 model isn't optimized properly with transformers >=4.27

3 participants